test_data_repositories.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright 2023 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import os
  12. import pytest
  13. from taipy.core.data._data_fs_repository import _DataFSRepository
  14. from taipy.core.data._data_sql_repository import _DataSQLRepository
  15. from taipy.core.data.data_node import DataNode, DataNodeId
  16. from taipy.core.exceptions import ModelNotFound
  17. class TestDataNodeRepository:
  18. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  19. def test_save_and_load(self, data_node, repo, init_sql_repo):
  20. repository = repo()
  21. repository._save(data_node)
  22. obj = repository._load(data_node.id)
  23. assert isinstance(obj, DataNode)
  24. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  25. def test_exists(self, data_node, repo, init_sql_repo):
  26. repository = repo()
  27. repository._save(data_node)
  28. assert repository._exists(data_node.id)
  29. assert not repository._exists("not-existed-data-node")
  30. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  31. def test_load_all(self, data_node, repo, init_sql_repo):
  32. repository = repo()
  33. for i in range(10):
  34. data_node.id = DataNodeId(f"data_node-{i}")
  35. repository._save(data_node)
  36. data_nodes = repository._load_all()
  37. assert len(data_nodes) == 10
  38. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  39. def test_load_all_with_filters(self, data_node, repo, init_sql_repo):
  40. repository = repo()
  41. for i in range(10):
  42. data_node.id = DataNodeId(f"data_node-{i}")
  43. data_node._owner_id = f"task-{i}"
  44. repository._save(data_node)
  45. objs = repository._load_all(filters=[{"owner_id": "task-2"}])
  46. assert len(objs) == 1
  47. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  48. def test_delete(self, data_node, repo, init_sql_repo):
  49. repository = repo()
  50. repository._save(data_node)
  51. repository._delete(data_node.id)
  52. with pytest.raises(ModelNotFound):
  53. repository._load(data_node.id)
  54. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  55. def test_delete_all(self, data_node, repo, init_sql_repo):
  56. repository = repo()
  57. for i in range(10):
  58. data_node.id = DataNodeId(f"data_node-{i}")
  59. repository._save(data_node)
  60. assert len(repository._load_all()) == 10
  61. repository._delete_all()
  62. assert len(repository._load_all()) == 0
  63. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  64. def test_delete_many(self, data_node, repo, init_sql_repo):
  65. repository = repo()
  66. for i in range(10):
  67. data_node.id = DataNodeId(f"data_node-{i}")
  68. repository._save(data_node)
  69. objs = repository._load_all()
  70. assert len(objs) == 10
  71. ids = [x.id for x in objs[:3]]
  72. repository._delete_many(ids)
  73. assert len(repository._load_all()) == 7
  74. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  75. def test_delete_by(self, data_node, repo, init_sql_repo):
  76. repository = repo()
  77. # Create 5 entities with version 1.0 and 5 entities with version 2.0
  78. for i in range(10):
  79. data_node.id = DataNodeId(f"data_node-{i}")
  80. data_node._version = f"{(i+1) // 5}.0"
  81. repository._save(data_node)
  82. objs = repository._load_all()
  83. assert len(objs) == 10
  84. repository._delete_by("version", "1.0")
  85. assert len(repository._load_all()) == 5
  86. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  87. def test_search(self, data_node, repo, init_sql_repo):
  88. repository = repo()
  89. for i in range(10):
  90. data_node.id = DataNodeId(f"data_node-{i}")
  91. data_node._owner_id = f"task-{i}"
  92. repository._save(data_node)
  93. assert len(repository._load_all()) == 10
  94. objs = repository._search("owner_id", "task-2")
  95. assert len(objs) == 1
  96. assert isinstance(objs[0], DataNode)
  97. objs = repository._search("owner_id", "task-2", filters=[{"version": "random_version_number"}])
  98. assert len(objs) == 1
  99. assert isinstance(objs[0], DataNode)
  100. assert repository._search("owner_id", "task-2", filters=[{"version": "non_existed_version"}]) == []
  101. @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
  102. def test_export(self, tmpdir, data_node, repo, init_sql_repo):
  103. repository = repo()
  104. repository._save(data_node)
  105. repository._export(data_node.id, tmpdir.strpath)
  106. dir_path = repository.dir_path if repo == _DataFSRepository else os.path.join(tmpdir.strpath, "data_node")
  107. assert os.path.exists(os.path.join(dir_path, f"{data_node.id}.json"))