test_data_manager.py 37 KB


  1. # Copyright 2021-2025 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import os
  12. import pathlib
  13. import pandas as pd
  14. import pytest
  15. from pandas.testing import assert_frame_equal
  16. from taipy import Scope
  17. from taipy.common.config import Config
  18. from taipy.core._version._version_manager import _VersionManager
  19. from taipy.core.config.data_node_config import DataNodeConfig
  20. from taipy.core.data._data_manager import _DataManager
  21. from taipy.core.data.csv import CSVDataNode
  22. from taipy.core.data.data_node_id import DataNodeId
  23. from taipy.core.data.in_memory import InMemoryDataNode
  24. from taipy.core.data.pickle import PickleDataNode
  25. from taipy.core.exceptions.exceptions import InvalidDataNodeType, ModelNotFound, NoData
  26. from taipy.core.reason import EntityDoesNotExist, NotGlobalScope, WrongConfigType
  27. from tests.core.utils.named_temporary_file import NamedTemporaryFile
  def file_exists(file_path: str) -> bool:
      """Report whether ``file_path`` points at an existing filesystem entry.

      Args:
          file_path: Path to probe.

      Returns:
          Whatever ``os.path.exists`` answers for that path.
      """
      is_present = os.path.exists(file_path)
      return is_present
  30. class TestDataManager:
  def test_create_data_node_and_modify_properties_does_not_modify_config(self):
      """Adding a property to a data node must leave its config unchanged."""
      config = Config.configure_data_node(id="name", foo="bar")
      data_node = _DataManager._create(config, None, None)

      def assert_config_is_pristine():
          # The config keeps its own "foo" and never gains "baz".
          assert config.properties.get("foo") == "bar"
          assert config.properties.get("baz") is None

      assert_config_is_pristine()

      data_node.properties["baz"] = "qux"
      _DataManager._update(data_node)

      assert_config_is_pristine()
      assert data_node.properties.get("foo") == "bar"
      assert data_node.properties.get("baz") == "qux"
  def test_can_create(self):
      """Check the answers and reasons `_can_create` gives for several inputs."""
      scenario_cfg = Config.configure_data_node("dn", 10, scope=Scope.SCENARIO)
      global_cfg = Config.configure_data_node(
          id="global_dn", storage_type="in_memory", scope=Scope.GLOBAL, data=10
      )

      # Accepted: no config at all.
      verdict = _DataManager._can_create()
      assert bool(verdict) is True
      assert verdict._reasons == {}

      # Accepted: a GLOBAL-scoped config.
      verdict = _DataManager._can_create(global_cfg)
      assert bool(verdict) is True
      assert verdict._reasons == {}

      # Refused: a SCENARIO-scoped config, with a NotGlobalScope reason.
      verdict = _DataManager._can_create(scenario_cfg)
      assert bool(verdict) is False
      assert verdict._reasons[scenario_cfg.id] == {NotGlobalScope(scenario_cfg.id)}
      first_reason = list(verdict._reasons[scenario_cfg.id])[0]
      assert str(first_reason) == f"Data node config '{scenario_cfg.id}' does not have GLOBAL scope"

      # Refused: something that is not a DataNodeConfig.
      verdict = _DataManager._can_create(1)
      assert bool(verdict) is False
      assert verdict._reasons["1"] == {WrongConfigType("1", DataNodeConfig.__name__)}
      first_reason = list(verdict._reasons["1"])[0]
      assert str(first_reason) == "Object '1' must be a valid DataNodeConfig"
  def test_create_data_node_with_name_provided(self):
      """A `name` given in the config is exposed as the data node's name."""
      named_config = Config.configure_data_node(id="dn", foo="bar", name="acb")
      created = _DataManager._create(named_config, None, None)
      assert created.name == "acb"
  def test_create_and_get_csv_data_node(self):
      """Create a CSV data node and read it back by id and by entity.

      The DataNodeConfig uses the csv storage type and does not set a scope
      (the node is expected to come back with SCENARIO scope). The node is
      created without owner and without parents.

      The same assertions run for both lookup keys, so lookup by entity is
      checked as thoroughly as lookup by id, including the `separator`,
      `exposed_type`, `is_generated` and editor-state checks.
      """
      csv_dn_config = Config.configure_data_node(id="foo", storage_type="csv", path="bar", has_header=True)
      csv_dn = _DataManager._create(csv_dn_config, None, None)

      assert isinstance(csv_dn, CSVDataNode)
      assert isinstance(_DataManager._get(csv_dn.id), CSVDataNode)
      assert _DataManager._exists(csv_dn.id)

      # First pass looks the node up by id, second pass by the entity itself.
      for key in (csv_dn.id, csv_dn):
          assert _DataManager._get(key) is not None
          assert _DataManager._get(key).id == csv_dn.id
          assert _DataManager._get(key).config_id == "foo"
          assert _DataManager._get(key).config_id == csv_dn.config_id
          assert _DataManager._get(key).scope == Scope.SCENARIO
          assert _DataManager._get(key).scope == csv_dn.scope
          assert _DataManager._get(key).owner_id is None
          assert _DataManager._get(key).owner_id == csv_dn.owner_id
          assert _DataManager._get(key).parent_ids == set()
          assert _DataManager._get(key).parent_ids == csv_dn.parent_ids
          assert _DataManager._get(key).last_edit_date is None
          assert _DataManager._get(key).last_edit_date == csv_dn.last_edit_date
          assert _DataManager._get(key).job_ids == []
          assert _DataManager._get(key).job_ids == csv_dn.job_ids
          assert not _DataManager._get(key).is_ready_for_reading
          assert _DataManager._get(key).is_ready_for_reading == csv_dn.is_ready_for_reading

          # path, encoding, separator, has_header, exposed_type, is_generated
          assert len(_DataManager._get(key).properties) == 6
          assert _DataManager._get(key).properties.get("path") == "bar"
          assert _DataManager._get(key).properties.get("encoding") == "utf-8"
          assert _DataManager._get(key).properties.get("separator") == ","
          assert _DataManager._get(key).properties.get("has_header") is True
          assert _DataManager._get(key).properties.get("exposed_type") == "pandas"
          assert _DataManager._get(key).properties.get("is_generated") is False
          assert _DataManager._get(key).properties == csv_dn.properties

          # Nobody is editing a freshly created node.
          assert _DataManager._get(key).edit_in_progress is False
          assert _DataManager._get(key)._editor_id is None
          assert _DataManager._get(key)._editor_expiration_date is None
  def test_edit_and_get_data_node(self):
      """Lock then unlock a pickle data node and check its stored state each time."""
      pickle_config = Config.configure_pickle_data_node(id="foo")
      node = _DataManager._create(pickle_config, None, None)

      def assert_untouched_by_locking():
          # The edit date stays unset and the node keeps exactly two
          # properties: is_generated and path.
          assert _DataManager._get(node.id).last_edit_date is None
          assert len(_DataManager._get(node.id).properties) == 2
          assert isinstance(_DataManager._get(node.id).properties.get("path"), str)
          assert _DataManager._get(node.id).properties.get("is_generated") is True

      # Freshly created: no editor recorded.
      assert_untouched_by_locking()
      assert not _DataManager._get(node.id).edit_in_progress
      assert _DataManager._get(node.id)._editor_id is None
      assert _DataManager._get(node.id)._editor_expiration_date is None

      # Locked by "foo": editor and expiration date are set.
      node.lock_edit("foo")
      assert_untouched_by_locking()
      assert _DataManager._get(node.id).edit_in_progress
      assert _DataManager._get(node.id).editor_id == "foo"
      assert _DataManager._get(node.id).editor_expiration_date is not None

      # Unlocked by "foo": editor information is cleared again.
      node.unlock_edit("foo")
      assert_untouched_by_locking()
      assert not _DataManager._get(node.id).edit_in_progress
      assert _DataManager._get(node.id).editor_id is None
      assert _DataManager._get(node.id).editor_expiration_date is None
  def test_create_and_get_in_memory_data_node(self):
      """Create an in-memory data node and read it back by id and by entity.

      The DataNodeConfig uses the in_memory storage type, a SCENARIO scope and
      some default data; the node is created with an owner id and one parent.
      """
      config = Config.configure_data_node(
          id="baz", storage_type="in_memory", scope=Scope.SCENARIO, default_data="qux", other_data="foo"
      )
      in_mem_dn = _DataManager._create(config, "Scenario_id", {"task_id"})
      fetch = _DataManager._get

      assert isinstance(in_mem_dn, InMemoryDataNode)
      assert isinstance(fetch(in_mem_dn.id), InMemoryDataNode)
      assert _DataManager._exists(in_mem_dn.id)

      # First pass looks the node up by id, second pass by the entity itself.
      for key in (in_mem_dn.id, in_mem_dn):
          assert fetch(key) is not None
          assert fetch(key).id == in_mem_dn.id
          assert fetch(key).config_id == "baz"
          assert fetch(key).config_id == in_mem_dn.config_id
          assert fetch(key).scope == Scope.SCENARIO
          assert fetch(key).scope == in_mem_dn.scope
          assert fetch(key).owner_id == "Scenario_id"
          assert fetch(key).owner_id == in_mem_dn.owner_id
          assert fetch(key).parent_ids == {"task_id"}
          assert fetch(key).parent_ids == in_mem_dn.parent_ids
          # The node has an edit date and is readable right after creation.
          assert fetch(key).last_edit_date is not None
          assert fetch(key).last_edit_date == in_mem_dn.last_edit_date
          assert fetch(key).job_ids == []
          assert fetch(key).job_ids == in_mem_dn.job_ids
          assert fetch(key).is_ready_for_reading
          assert fetch(key).is_ready_for_reading == in_mem_dn.is_ready_for_reading
          # Only the custom "other_data" entry is kept as a property.
          assert len(fetch(key).properties) == 1
          assert fetch(key).properties.get("other_data") == "foo"
          assert fetch(key).properties == in_mem_dn.properties
  def test_create_and_get_pickle_data_node(self):
      """Create a pickle data node and read it back by id and by entity.

      The DataNodeConfig uses the pickle storage type, a CYCLE scope and no
      default data; the node is created without owner and with two parents.
      """
      config = Config.configure_data_node(id="plop", storage_type="pickle", scope=Scope.CYCLE)
      pickle_dn = _DataManager._create(config, None, {"task_id_1", "task_id_2"})
      fetch = _DataManager._get

      assert isinstance(pickle_dn, PickleDataNode)
      assert isinstance(fetch(pickle_dn.id), PickleDataNode)
      assert _DataManager._exists(pickle_dn.id)

      # First pass looks the node up by id, second pass by the entity itself.
      for key in (pickle_dn.id, pickle_dn):
          assert fetch(key) is not None
          assert fetch(key).id == pickle_dn.id
          assert fetch(key).config_id == "plop"
          assert fetch(key).config_id == pickle_dn.config_id
          assert fetch(key).scope == Scope.CYCLE
          assert fetch(key).scope == pickle_dn.scope
          assert fetch(key).owner_id is None
          assert fetch(key).owner_id == pickle_dn.owner_id
          assert fetch(key).parent_ids == {"task_id_1", "task_id_2"}
          assert fetch(key).parent_ids == pickle_dn.parent_ids
          # No edit date was recorded, and the node is not readable.
          assert fetch(key).last_edit_date is None
          assert fetch(key).last_edit_date == pickle_dn.last_edit_date
          assert fetch(key).job_ids == []
          assert fetch(key).job_ids == pickle_dn.job_ids
          assert not fetch(key).is_ready_for_reading
          assert fetch(key).is_ready_for_reading == pickle_dn.is_ready_for_reading
          # is_generated and path
          assert len(fetch(key).properties) == 2
          assert fetch(key).properties == pickle_dn.properties
  def test_create_raises_exception_with_wrong_type(self):
      """Creating from a config whose storage type is "bar" raises `InvalidDataNodeType`."""
      bogus_config = DataNodeConfig(id="foo", storage_type="bar", scope=DataNodeConfig._DEFAULT_SCOPE)
      with pytest.raises(InvalidDataNodeType):
          _DataManager._create(bogus_config, None, None)
  def test_create_from_same_config_generates_new_data_node_and_new_id(self):
      """Two creations from one config yield data nodes with distinct ids."""
      shared_config = Config.configure_data_node(id="foo", storage_type="in_memory")
      first = _DataManager._create(shared_config, None, None)
      second = _DataManager._create(shared_config, None, None)
      assert second.id != first.id
  def test_create_uses_overridden_attributes_in_config_file(self):
      """Check the path of CSV data nodes created after loading an override file.

      After `Config.override` loads data_sample/config.toml, the node built
      from config "foo" is expected to use "path_from_config_file", while the
      node built from config "baz" keeps the path given in code.
      """
      tests_dir = pathlib.Path(__file__).parent.resolve()
      Config.override(os.path.join(tests_dir, "data_sample/config.toml"))

      for config_id, expected_path in (("foo", "path_from_config_file"), ("baz", "bar")):
          cfg = Config.configure_data_node(id=config_id, storage_type="csv", path="bar", has_header=True)
          node = _DataManager._create(cfg, None, None)
          assert node.config_id == config_id
          assert isinstance(node, CSVDataNode)
          assert node._path == expected_path
          assert node.properties["has_header"]
  def test_get_if_not_exists(self):
      """Loading an unknown id straight from the repository raises `ModelNotFound`."""
      unknown_id = "test_data_node_2"
      with pytest.raises(ModelNotFound):
          _DataManager._repository._load(unknown_id)
  def test_get_all(self):
      """`_get_all` grows by one for every data node created."""

      def stored_count():
          return len(_DataManager._get_all())

      assert stored_count() == 0

      foo_config = Config.configure_data_node(id="foo", storage_type="in_memory")
      _DataManager._create(foo_config, None, None)
      assert stored_count() == 1

      baz_config = Config.configure_data_node(id="baz", storage_type="in_memory")
      for _ in range(2):
          _DataManager._create(baz_config, None, None)
      assert stored_count() == 3

      # One node came from "foo", two from "baz".
      for config_id, expected in (("foo", 1), ("baz", 2)):
          matching = [dn for dn in _DataManager._get_all() if dn.config_id == config_id]
          assert len(matching) == expected
  def test_get_all_on_multiple_versions_environment(self):
      """`_get_all` and `_get_all_by` only see data nodes of the selected version.

      Five data nodes are saved for each of versions 1.0 and 2.0. Config ids
      are shifted by the version number, so "config_id_1" exists only in
      version 1.0 and "config_id_6" only in version 2.0.
      """
      for version_number in (1, 2):
          for index in range(5):
              node = InMemoryDataNode(
                  f"config_id_{index + version_number}",
                  Scope.SCENARIO,
                  id=DataNodeId(f"id{index}_v{version_number}"),
                  version=f"{version_number}.0",
              )
              _DataManager._repository._save(node)

      # (version, expected matches for config_id_1, expected matches for config_id_6)
      expectations = (("1.0", 1, 0), ("2.0", 0, 1))
      for version, count_id_1, count_id_6 in expectations:
          # Each version is selected as experiment version first, then as development version.
          for select_version in (_VersionManager._set_experiment_version, _VersionManager._set_development_version):
              select_version(version)
              assert len(_DataManager._get_all()) == 5
              by_id_1 = _DataManager._get_all_by(filters=[{"version": version, "config_id": "config_id_1"}])
              assert len(by_id_1) == count_id_1
              by_id_6 = _DataManager._get_all_by(filters=[{"version": version, "config_id": "config_id_6"}])
              assert len(by_id_6) == count_id_6
  def test_save_and_update(self):
      """Saving adds the data node; `_update` persists a changed attribute."""
      node = InMemoryDataNode(
          "config_id",
          Scope.SCENARIO,
          id=DataNodeId("id"),
          owner_id=None,
          parent_ids={"task_id_1"},
          last_edit_date=None,
          edits=[],
          edit_in_progress=False,
          properties={"foo": "bar"},
      )

      # Nothing is stored until the node is saved through the repository.
      assert len(_DataManager._get_all()) == 0
      assert not _DataManager._exists(node.id)

      _DataManager._repository._save(node)
      assert len(_DataManager._get_all()) == 1
      assert _DataManager._exists(node.id)

      # Change an attribute on the in-memory object, then push it.
      node._config_id = "foo"
      assert node.config_id == "foo"

      _DataManager._update(node)
      assert len(_DataManager._get_all()) == 1
      assert node.config_id == "foo"
      assert _DataManager._get(node.id).config_id == "foo"
  def test_delete(self):
      """`_delete` removes one data node; `_delete_all` removes the rest."""
      first, second, third = (
          InMemoryDataNode("config_id", Scope.SCENARIO, id=DataNodeId(node_id))
          for node_id in ("id_1", "id_2", "id_3")
      )
      assert len(_DataManager._get_all()) == 0

      for node in (first, second, third):
          _DataManager._repository._save(node)
      assert len(_DataManager._get_all()) == 3
      assert all(_DataManager._exists(node.id) for node in (first, second, third))

      # Deleting the first node leaves the other two untouched.
      _DataManager._delete(first.id)
      assert len(_DataManager._get_all()) == 2
      assert _DataManager._get(second.id).id == second.id
      assert _DataManager._get(third.id).id == third.id
      assert _DataManager._get(first.id) is None
      assert all(_DataManager._exists(node.id) for node in (second, third))
      assert not _DataManager._exists(first.id)

      _DataManager._delete_all()
      assert len(_DataManager._get_all()) == 0
      assert not any(_DataManager._exists(node.id) for node in (second, third))
  def test_get_or_create(self):
      """`_bulk_get_or_create` reuses or creates data nodes depending on scope and arguments.

      The extra arguments are passed through positionally, in the same order
      the helper receives them.
      """

      def get_or_create(config, *extra):
          # Single-config shortcut around the bulk API.
          created_by_config = _DataManager._bulk_get_or_create([config], *extra)
          return created_by_config[config]

      def stored_count():
          return len(_DataManager._get_all())

      _DataManager._delete_all()

      global_cfg = Config.configure_data_node(
          id="test_data_node", storage_type="in_memory", scope=Scope.GLOBAL, data="In memory Data Node"
      )
      cycle_cfg = Config.configure_data_node(
          id="test_data_node1", storage_type="in_memory", scope=Scope.CYCLE, data="In memory Data Node"
      )
      scenario_cfg = Config.configure_data_node(
          id="test_data_node2", storage_type="in_memory", scope=Scope.SCENARIO, data="In memory scenario"
      )
      assert stored_count() == 0

      # GLOBAL scope: the second request returns the node created by the first.
      global_dn = get_or_create(global_cfg, None, None)
      assert stored_count() == 1
      global_dn_again = get_or_create(global_cfg, None)
      assert stored_count() == 1
      assert global_dn.id == global_dn_again.id

      # SCENARIO scope: the same node comes back while "scenario_id" is passed.
      scenario_dn = get_or_create(scenario_cfg, None, "scenario_id")
      assert stored_count() == 2
      scenario_dn_second = get_or_create(scenario_cfg, None, "scenario_id")
      assert stored_count() == 2
      assert scenario_dn.id == scenario_dn_second.id
      scenario_dn_third = get_or_create(scenario_cfg, None, "scenario_id")
      assert stored_count() == 2
      assert scenario_dn.id == scenario_dn_second.id
      assert scenario_dn_second.id == scenario_dn_third.id

      # ...and passing "scenario_id_2" instead yields a new node.
      scenario_dn_other = get_or_create(scenario_cfg, None, "scenario_id_2")
      assert stored_count() == 3
      assert scenario_dn.id == scenario_dn_second.id
      assert scenario_dn_second.id == scenario_dn_third.id
      assert scenario_dn_third.id != scenario_dn_other.id
      assert stored_count() == 3

      # CYCLE scope: the same node comes back for "cycle_id", whatever the last argument.
      cycle_dn = get_or_create(cycle_cfg, "cycle_id", None)
      assert stored_count() == 4
      repeats = []
      for last_arg in (None, "scenario_id", None, "scenario_id", "scenario_id_2"):
          repeated = get_or_create(cycle_cfg, "cycle_id", last_arg)
          assert stored_count() == 4
          assert cycle_dn.id == repeated.id
          repeats.append(repeated)

      # Every consecutive pair of repeated lookups agrees as well.
      for earlier, later in zip(repeats, repeats[1:]):
          assert earlier.id == later.id
  def test_ensure_persistence_of_data_node(self):
      """Data nodes created through one manager instance are visible from a new one."""
      manager = _DataManager()
      manager._delete_all()

      first_config = Config.configure_data_node(
          id="data_node_1", storage_type="in_memory", data="In memory sequence 2"
      )
      second_config = Config.configure_data_node(
          id="data_node_2", storage_type="in_memory", data="In memory sequence 2"
      )
      manager._bulk_get_or_create([first_config, second_config])
      assert len(manager._get_all()) == 2

      # Drop the manager and build a new one, so the data nodes have to be
      # fetched from the storage system.
      del manager
      manager = _DataManager()

      manager._bulk_get_or_create([first_config])
      assert len(manager._get_all()) == 2

      manager._delete_all()
  @pytest.mark.parametrize(
      "storage_type,path",
      [
          ("pickle", "pickle_file_path"),
          ("csv", "csv_file"),
          ("excel", "excel_file"),
          ("json", "json_file"),
          ("parquet", "parquet_file_path"),
      ],
  )
  def test_read(self, storage_type, path, request):
      """Reading raises `NoData` for path "non_exist_path" and returns data for the fixture path.

      `path` is the name of a fixture; its value is resolved through `request`.
      """
      existing_path = request.getfixturevalue(path)

      missing_config = Config.configure_data_node(id="d1", storage_type=storage_type, path="non_exist_path")
      existing_config = Config.configure_data_node(id="d2", storage_type=storage_type, path=existing_path)
      missing_dn = _DataManager._create(missing_config, None, None)
      existing_dn = _DataManager._create(existing_config, None, None)

      with pytest.raises(NoData):
          _DataManager._read(missing_dn)

      assert existing_dn._read() is not None
  @pytest.mark.parametrize(
      "storage_type,path",
      [
          ("pickle", "pickle_file_path"),
          ("csv", "csv_file"),
          ("parquet", "parquet_file_path"),
      ],
  )
  def test_write(self, storage_type, path, request):
      """A DataFrame written through the manager is read back unchanged.

      `path` is the name of a fixture; its value is resolved through `request`.
      """
      target_path = request.getfixturevalue(path)
      config = Config.configure_data_node(id="d2", storage_type=storage_type, path=target_path)
      node = _DataManager._create(config, None, None)

      rows = [{"a": 11, "b": 12, "c": 13}, {"a": 14, "b": 15, "c": 16}]
      written = pd.DataFrame(rows)
      _DataManager._write(node, written)

      assert_frame_equal(node._read(), written)
  @pytest.mark.parametrize(
      "storage_type,path",
      [
          ("csv", "csv_file"),
          ("parquet", "parquet_file_path"),
      ],
  )
  def test_append(self, storage_type, path, request):
      """Appending rows leaves the previous content in place, followed by the new rows.

      `path` is the name of a fixture; its value is resolved through `request`.
      """
      target_path = request.getfixturevalue(path)
      config = Config.configure_data_node(id="d2", storage_type=storage_type, path=target_path)
      node = _DataManager._create(config, None, None)

      before = _DataManager._read(node)
      rows = [{"a": 11, "b": 12, "c": 13}, {"a": 14, "b": 15, "c": 16}]
      appended = pd.DataFrame(rows)
      _DataManager._append(node, appended)

      expected = pd.concat([before, appended], ignore_index=True)
      assert_frame_equal(node._read(), expected)
  @pytest.mark.parametrize(
      "storage_type,path",
      [
          ("pickle", "pickle_file_path"),
          ("csv", "csv_file"),
          ("excel", "excel_file"),
          ("json", "json_file"),
          ("parquet", "parquet_file_path"),
      ],
  )
  def test_clean_generated_files(self, storage_type, path, request):
      """Cleaning removes the files of path-less data nodes and keeps the user-provided one.

      `path` is the name of a fixture; its value is resolved through `request`.
      """
      user_path = request.getfixturevalue(path)

      # "d1" gets an explicit path; "d2" and "d3" are configured without one.
      user_config = Config.configure_data_node(
          id="d1", storage_type=storage_type, path=user_path, default_data={"a": [1], "b": [2]}
      )
      generated_configs = [
          Config.configure_data_node(id=config_id, storage_type=storage_type, default_data={"a": [1], "b": [2]})
          for config_id in ("d2", "d3")
      ]

      created = _DataManager._bulk_get_or_create([user_config, *generated_configs])
      user_dn = created[user_config]
      generated_dns = [created[config] for config in generated_configs]

      # The file behind the explicit path is still there after cleaning.
      _DataManager._clean_generated_file(user_dn.id)
      assert file_exists(user_dn.path)

      _DataManager._clean_generated_files(generated_dns)
      for generated_dn in generated_dns:
          assert not file_exists(generated_dn.path)
  @pytest.mark.parametrize(
      "storage_type,path",
      [
          ("pickle", "pickle_file_path"),
          ("csv", "csv_file"),
          ("excel", "excel_file"),
          ("json", "json_file"),
          ("parquet", "parquet_file_path"),
      ],
  )
  def test_delete_does_clean_generated_pickle_files(self, storage_type, path, request):
      """Deleting removes the files of path-less data nodes and keeps the user-provided one.

      The three deletion entry points are covered: `_delete`, `_delete_many`
      and `_delete_all`. `path` is the name of a fixture; its value is
      resolved through `request`.
      """
      user_path = request.getfixturevalue(path)

      # "d1" gets an explicit path; "d2", "d3" and "d4" are configured without one.
      user_config = Config.configure_data_node(
          id="d1", storage_type=storage_type, path=user_path, default_data={"a": [1], "b": [2]}
      )
      generated_configs = [
          Config.configure_data_node(id=config_id, storage_type=storage_type, default_data={"a": [1], "b": [2]})
          for config_id in ("d2", "d3", "d4")
      ]

      created = _DataManager._bulk_get_or_create([user_config, *generated_configs])
      user_dn = created[user_config]
      first_generated, second_generated, third_generated = (created[config] for config in generated_configs)

      # Single delete: the file behind the explicit path is still there.
      _DataManager._delete(user_dn.id)
      assert file_exists(user_dn.path)

      # Bulk delete by ids: both files are gone.
      _DataManager._delete_many([first_generated.id, second_generated.id])
      assert not file_exists(first_generated.path)
      assert not file_exists(second_generated.path)

      # Delete everything: the last remaining file is gone too.
      _DataManager._delete_all()
      assert not file_exists(third_generated.path)
  def test_create_dn_from_loaded_config_no_scope(self):
      """Two scenarios built from a loaded config without data node scopes give four data nodes."""
      # TOML content, one literal per line; no data node section sets a scope.
      toml_content = (
          "\n"
          "[TAIPY]\n"
          "[DATA_NODE.a]\n"
          'default_data = "4:int"\n'
          "[DATA_NODE.b]\n"
          "[TASK.t]\n"
          'function = "math.sqrt:function"\n'
          'inputs = [ "a:SECTION",]\n'
          'outputs = [ "b:SECTION",]\n'
          'skippable = "False:bool"\n'
          "[SCENARIO.s]\n"
          'tasks = [ "t:SECTION",]\n'
          'sequences.s_sequence = [ "t:SECTION",]\n'
          "[SCENARIO.s.comparators]\n"
      )
      config_file = NamedTemporaryFile(toml_content)

      from taipy import core as tp

      Config.override(config_file.filename)
      scenario_config = Config.scenarios["s"]
      for _ in range(2):
          tp.create_scenario(Config.scenarios["s"])

      assert scenario_config is not None
      assert len(tp.get_data_nodes()) == 4
  def test_create_dn_from_loaded_config_no_storage_type(self):
      """A loaded data node section without storage type yields a pickle data node."""
      # TOML content, one literal per line; only output_dn names a storage type.
      toml_content = (
          "\n"
          "[TAIPY]\n"
          "[DATA_NODE.input_dn]\n"
          'scope = "SCENARIO:SCOPE"\n'
          'default_data = "21:int"\n'
          "[DATA_NODE.output_dn]\n"
          'storage_type = "in_memory"\n'
          'scope = "SCENARIO:SCOPE"\n'
          "[TASK.double]\n"
          'inputs = [ "input_dn:SECTION",]\n'
          'function = "math.sqrt:function"\n'
          'outputs = [ "output_dn:SECTION",]\n'
          'skippable = "False:bool"\n'
          "[SCENARIO.my_scenario]\n"
          'tasks = [ "double:SECTION",]\n'
          'sequences.my_sequence = [ "double:SECTION",]\n'
          "[SCENARIO.my_scenario.comparators]\n"
      )
      config_file = NamedTemporaryFile(toml_content)

      from taipy import core as tp

      Config.override(config_file.filename)
      created_scenario = tp.create_scenario(Config.scenarios["my_scenario"])

      assert isinstance(created_scenario.input_dn, PickleDataNode)
      assert isinstance(created_scenario.output_dn, InMemoryDataNode)
  def test_create_dn_from_loaded_config_modified_default_config(self):
      """Check that a modified default data node configuration applies to a loaded TOML config.

      The default data node configuration is switched to `csv` before the file is
      loaded. `input_dn` declares no `storage_type` and is expected to be a
      `CSVDataNode`; `output_dn` explicitly declares `in_memory` and is expected to be
      an `InMemoryDataNode`.
      """
      from taipy import core as tp

      toml_content = """
      [TAIPY]
      [DATA_NODE.input_dn]
      scope = "SCENARIO:SCOPE"
      default_path="fake/path.csv"
      [DATA_NODE.output_dn]
      storage_type = "in_memory"
      scope = "SCENARIO:SCOPE"
      [TASK.double]
      inputs = [ "input_dn:SECTION",]
      function = "math.sqrt:function"
      outputs = [ "output_dn:SECTION",]
      skippable = "False:bool"
      [SCENARIO.my_scenario]
      tasks = [ "double:SECTION",]
      sequences.my_sequence = [ "double:SECTION",]
      [SCENARIO.my_scenario.comparators]
      """
      # Keep a reference to the temporary file for the whole test body.
      config_file = NamedTemporaryFile(toml_content)

      # The default is changed first, then the file config is applied on top of it.
      Config.set_default_data_node_configuration(storage_type="csv")
      Config.override(config_file.filename)

      created_scenario = tp.create_scenario(Config.scenarios["my_scenario"])

      # No storage type in the file: the modified default ("csv") is expected to apply.
      input_node = created_scenario.input_dn
      assert isinstance(input_node, CSVDataNode)

      # Explicit "in_memory" storage type must still win over the default.
      output_node = created_scenario.output_dn
      assert isinstance(output_node, InMemoryDataNode)
  def test_get_tasks_by_config_id(self):
      """Check that `_DataManager._get_by_config_id` returns the data nodes created from a config.

      Despite the method name, only data nodes are exercised: three configs are
      declared, then 3, 2 and 1 data nodes are created from them respectively, and
      the lookup by config id must return exactly those nodes.
      """
      cfg_a = Config.configure_data_node("dn_1", scope=Scope.SCENARIO)
      cfg_b = Config.configure_data_node("dn_2", scope=Scope.SCENARIO)
      cfg_c = Config.configure_data_node("dn_3", scope=Scope.SCENARIO)

      created_by_config = []
      expected_total = 0
      for cfg, how_many in ((cfg_a, 3), (cfg_b, 2), (cfg_c, 1)):
          nodes = [_DataManager._create(cfg, None, None) for _ in range(how_many)]
          created_by_config.append((cfg, nodes))
          # The overall count grows by exactly the number of nodes just created.
          expected_total += how_many
          assert len(_DataManager._get_all()) == expected_total

      for cfg, nodes in created_by_config:
          fetched = _DataManager._get_by_config_id(cfg.id)
          assert len(fetched) == len(nodes)
          # Compare ids order-independently.
          assert sorted(node.id for node in nodes) == sorted(node.id for node in fetched)
  def test_get_data_nodes_by_config_id_in_multiple_versions_environment(self):
      """Check `_DataManager._get_by_config_id` counts after switching experiment versions.

      For each of the experiment versions "1.0" and "2.0", three data nodes are
      created from the first config and two from the second. The lookup is expected
      to return 3 and 2 nodes both times, even though twice as many exist overall
      after the second batch.
      """
      first_cfg = Config.configure_data_node("dn_1", scope=Scope.SCENARIO)
      second_cfg = Config.configure_data_node("dn_2", scope=Scope.SCENARIO)

      for version in ("1.0", "2.0"):
          _VersionManager._set_experiment_version(version)

          for _ in range(3):
              _DataManager._create(first_cfg, None, None)
          for _ in range(2):
              _DataManager._create(second_cfg, None, None)

          # NOTE(review): counts do not accumulate across versions -- presumably the
          # lookup is filtered by the current experiment version; confirm.
          assert len(_DataManager._get_by_config_id(first_cfg.id)) == 3
          assert len(_DataManager._get_by_config_id(second_cfg.id)) == 2
  def test_can_duplicate(self):
      """Check `_DataManager._can_duplicate` for an existing data node and for an unknown id.

      An existing data node, passed either by id or as the entity itself, must yield a
      truthy result with no reasons. The unknown id "1" must yield a falsy result
      carrying a single `EntityDoesNotExist` reason.
      """
      dn_config = Config.configure_data_node("dn_1")
      dn = _DataManager._create(dn_config, None, None)

      # Both the id and the entity itself are accepted as argument.
      for existing in (dn.id, dn):
          reasons = _DataManager._can_duplicate(existing)
          assert bool(reasons)
          assert reasons._reasons == {}

      missing_id = "1"
      reasons = _DataManager._can_duplicate(missing_id)
      assert not bool(reasons)
      assert reasons._reasons[missing_id] == {EntityDoesNotExist(missing_id)}
      first_reason = list(reasons._reasons[missing_id])[0]
      assert str(first_reason) == "Entity '1' does not exist in the repository"