# test_sql_data_node.py
  1. # Copyright 2021-2025 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. from importlib import util
  12. from unittest.mock import patch
  13. import numpy as np
  14. import pandas as pd
  15. import pytest
  16. from pandas.testing import assert_frame_equal
  17. from taipy import Scope
  18. from taipy.common.config import Config
  19. from taipy.core.data._data_manager_factory import _DataManagerFactory
  20. from taipy.core.data.data_node_id import DataNodeId
  21. from taipy.core.data.operator import JoinOperator, Operator
  22. from taipy.core.data.sql import SQLDataNode
  23. from taipy.core.exceptions.exceptions import (
  24. MissingAppendQueryBuilder,
  25. MissingReadQuery,
  26. MissingRequiredProperty,
  27. MissingWriteQueryBuilder,
  28. UnknownDatabaseEngine,
  29. )
  30. class MyCustomObject:
  31. def __init__(self, foo=None, bar=None, *args, **kwargs):
  32. self.foo = foo
  33. self.bar = bar
  34. self.args = args
  35. self.kwargs = kwargs
  36. def my_write_query_builder_with_pandas(data: pd.DataFrame):
  37. insert_data = data.to_dict("records")
  38. return ["DELETE FROM example", ("INSERT INTO example VALUES (:foo, :bar)", insert_data)]
  39. def my_append_query_builder_with_pandas(data: pd.DataFrame):
  40. insert_data = data.to_dict("records")
  41. return [("INSERT INTO example VALUES (:foo, :bar)", insert_data)]
  42. def single_write_query_builder(data):
  43. return "DELETE FROM example"
  44. class TestSQLDataNode:
  45. __sql_properties = [
  46. {
  47. "db_name": "taipy.sqlite3",
  48. "db_engine": "sqlite",
  49. "read_query": "SELECT * FROM example",
  50. "write_query_builder": my_write_query_builder_with_pandas,
  51. "db_extra_args": {
  52. "TrustServerCertificate": "yes",
  53. "other": "value",
  54. },
  55. },
  56. ]
  57. if util.find_spec("pyodbc"):
  58. __sql_properties.append(
  59. {
  60. "db_username": "sa",
  61. "db_password": "Passw0rd",
  62. "db_name": "taipy",
  63. "db_engine": "mssql",
  64. "read_query": "SELECT * FROM example",
  65. "write_query_builder": my_write_query_builder_with_pandas,
  66. "db_extra_args": {
  67. "TrustServerCertificate": "yes",
  68. },
  69. },
  70. )
  71. if util.find_spec("pymysql"):
  72. __sql_properties.append(
  73. {
  74. "db_username": "sa",
  75. "db_password": "Passw0rd",
  76. "db_name": "taipy",
  77. "db_engine": "mysql",
  78. "read_query": "SELECT * FROM example",
  79. "write_query_builder": my_write_query_builder_with_pandas,
  80. "db_extra_args": {
  81. "TrustServerCertificate": "yes",
  82. },
  83. },
  84. )
  85. if util.find_spec("psycopg2"):
  86. __sql_properties.append(
  87. {
  88. "db_username": "sa",
  89. "db_password": "Passw0rd",
  90. "db_name": "taipy",
  91. "db_engine": "postgresql",
  92. "read_query": "SELECT * FROM example",
  93. "write_query_builder": my_write_query_builder_with_pandas,
  94. "db_extra_args": {
  95. "TrustServerCertificate": "yes",
  96. },
  97. },
  98. )
  99. @pytest.mark.parametrize("properties", __sql_properties)
  100. def test_create(self, properties):
  101. sql_dn_config = Config.configure_sql_data_node(id="foo_bar", **properties)
  102. dn = _DataManagerFactory._build_manager()._create(sql_dn_config, None, None)
  103. assert isinstance(dn, SQLDataNode)
  104. assert dn.storage_type() == "sql"
  105. assert dn.config_id == "foo_bar"
  106. assert dn.scope == Scope.SCENARIO
  107. assert dn.id is not None
  108. assert dn.owner_id is None
  109. assert dn.job_ids == []
  110. assert dn.is_ready_for_reading
  111. assert dn.properties["exposed_type"] == "pandas"
  112. assert dn.properties["read_query"] == "SELECT * FROM example"
  113. assert dn.properties["write_query_builder"] == my_write_query_builder_with_pandas
  114. sql_dn_config_1 = Config.configure_sql_data_node(
  115. id="foo",
  116. **properties,
  117. append_query_builder=my_append_query_builder_with_pandas,
  118. exposed_type=MyCustomObject,
  119. )
  120. dn_1 = _DataManagerFactory._build_manager()._create(sql_dn_config_1, None, None)
  121. assert isinstance(dn, SQLDataNode)
  122. assert dn_1.properties["exposed_type"] == MyCustomObject
  123. assert dn_1.properties["append_query_builder"] == my_append_query_builder_with_pandas
  124. @pytest.mark.parametrize("properties", __sql_properties)
  125. def test_get_user_properties(self, properties):
  126. custom_properties = properties.copy()
  127. custom_properties["foo"] = "bar"
  128. dn = SQLDataNode(
  129. "foo_bar",
  130. Scope.SCENARIO,
  131. properties=custom_properties,
  132. )
  133. assert dn._get_user_properties() == {"foo": "bar"}
  134. @pytest.mark.parametrize(
  135. "properties",
  136. [
  137. {},
  138. {"read_query": "ready query"},
  139. {"read_query": "ready query", "write_query_builder": "write query"},
  140. {"read_query": "ready query", "write_query_builder": "write query", "db_username": "foo"},
  141. {
  142. "read_query": "ready query",
  143. "write_query_builder": "write query",
  144. "db_username": "foo",
  145. "db_password": "foo",
  146. },
  147. {
  148. "read_query": "ready query",
  149. "write_query_builder": "write query",
  150. "db_username": "foo",
  151. "db_password": "foo",
  152. "db_name": "foo",
  153. },
  154. {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "some engine"},
  155. {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "sqlite"},
  156. {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "mssql", "db_name": "foo"},
  157. {
  158. "read_query": "ready query",
  159. "write_query_builder": "write query",
  160. "db_engine": "mysql",
  161. "db_username": "foo",
  162. },
  163. {
  164. "read_query": "ready query",
  165. "write_query_builder": "write query",
  166. "db_engine": "postgresql",
  167. "db_username": "foo",
  168. "db_password": "foo",
  169. },
  170. ],
  171. )
  172. def test_create_with_missing_parameters(self, properties):
  173. with pytest.raises(MissingRequiredProperty):
  174. SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"))
  175. engine = properties.get("db_engine")
  176. if engine is not None and engine not in ["sqlite", "mssql", "mysql", "postgresql"]:
  177. with pytest.raises(UnknownDatabaseEngine):
  178. SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
  179. else:
  180. with pytest.raises(MissingRequiredProperty):
  181. SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
  182. @pytest.mark.parametrize("properties", __sql_properties)
  183. def test_write_query_builder(self, properties):
  184. custom_properties = properties.copy()
  185. custom_properties.pop("db_extra_args")
  186. dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
  187. _DataManagerFactory._build_manager()._repository._save(dn)
  188. with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
  189. # mock connection execute
  190. dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
  191. assert len(engine_mock.mock_calls[4].args) == 1
  192. assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
  193. assert len(engine_mock.mock_calls[5].args) == 2
  194. assert engine_mock.mock_calls[5].args[0].text == "INSERT INTO example VALUES (:foo, :bar)"
  195. assert engine_mock.mock_calls[5].args[1] == [
  196. {"foo": 1, "bar": 4},
  197. {"foo": 2, "bar": 5},
  198. {"foo": 3, "bar": 6},
  199. ]
  200. custom_properties["write_query_builder"] = single_write_query_builder
  201. dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
  202. _DataManagerFactory._build_manager()._repository._save(dn)
  203. with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
  204. # mock connection execute
  205. dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
  206. assert len(engine_mock.mock_calls[4].args) == 1
  207. assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
  208. @pytest.mark.parametrize("properties", __sql_properties)
  209. def test_write_only_datanode(self, properties):
  210. custom_properties = properties.copy()
  211. custom_properties.pop("read_query")
  212. dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
  213. _DataManagerFactory._build_manager()._repository._save(dn)
  214. with patch("sqlalchemy.engine.Engine.connect"):
  215. dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
  216. with pytest.raises(MissingReadQuery):
  217. dn.read()
  218. with pytest.raises(MissingAppendQueryBuilder):
  219. dn.append(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
  220. @pytest.mark.parametrize(
  221. "tmp_sqlite_path",
  222. [
  223. "tmp_sqlite_db_file_path",
  224. "tmp_sqlite_sqlite3_file_path",
  225. ],
  226. )
  227. def test_sqlite_read_file_with_different_extension(self, tmp_sqlite_path, request):
  228. tmp_sqlite_path = request.getfixturevalue(tmp_sqlite_path)
  229. folder_path, db_name, file_extension = tmp_sqlite_path
  230. properties = {
  231. "db_engine": "sqlite",
  232. "read_query": "SELECT * from example",
  233. "write_query_builder": single_write_query_builder,
  234. "db_name": db_name,
  235. "sqlite_folder_path": folder_path,
  236. "sqlite_file_extension": file_extension,
  237. }
  238. dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
  239. data = dn.read()
  240. assert data.equals(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))
  241. @pytest.mark.parametrize("properties", __sql_properties)
  242. def test_read_only_datanode(self, properties):
  243. custom_properties = properties.copy()
  244. custom_properties.pop("write_query_builder")
  245. dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
  246. _DataManagerFactory._build_manager()._repository._save(dn)
  247. with patch("sqlalchemy.engine.Engine.connect"):
  248. dn.read()
  249. with pytest.raises(MissingWriteQueryBuilder):
  250. dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
  251. with pytest.raises(MissingAppendQueryBuilder):
  252. dn.append(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
  253. def test_sqlite_append_pandas(self, tmp_sqlite_sqlite3_file_path):
  254. folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
  255. properties = {
  256. "db_engine": "sqlite",
  257. "read_query": "SELECT * FROM example",
  258. "write_query_builder": my_write_query_builder_with_pandas,
  259. "append_query_builder": my_append_query_builder_with_pandas,
  260. "db_name": db_name,
  261. "sqlite_folder_path": folder_path,
  262. "sqlite_file_extension": file_extension,
  263. }
  264. dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
  265. _DataManagerFactory._build_manager()._repository._save(dn)
  266. original_data = pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}])
  267. data = dn.read()
  268. assert_frame_equal(data, original_data)
  269. append_data_1 = pd.DataFrame([{"foo": 5, "bar": 6}, {"foo": 7, "bar": 8}])
  270. dn.append(append_data_1)
  271. assert_frame_equal(dn.read(), pd.concat([original_data, append_data_1]).reset_index(drop=True))
  272. def test_sqlite_append_without_append_query_builder(self, tmp_sqlite_sqlite3_file_path):
  273. folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
  274. properties = {
  275. "db_engine": "sqlite",
  276. "read_query": "SELECT * FROM example",
  277. "write_query_builder": my_write_query_builder_with_pandas,
  278. "db_name": db_name,
  279. "sqlite_folder_path": folder_path,
  280. "sqlite_file_extension": file_extension,
  281. }
  282. dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
  283. with pytest.raises(MissingAppendQueryBuilder):
  284. dn.append(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))
  285. def test_filter_pandas_exposed_type(self, tmp_sqlite_sqlite3_file_path):
  286. folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
  287. properties = {
  288. "db_engine": "sqlite",
  289. "read_query": "SELECT * FROM example",
  290. "write_query_builder": my_write_query_builder_with_pandas,
  291. "db_name": db_name,
  292. "sqlite_folder_path": folder_path,
  293. "sqlite_file_extension": file_extension,
  294. "exposed_type": "pandas",
  295. }
  296. dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
  297. _DataManagerFactory._build_manager()._repository._save(dn)
  298. dn.write(
  299. pd.DataFrame(
  300. [
  301. {"foo": 1, "bar": 1},
  302. {"foo": 1, "bar": 2},
  303. {"foo": 1, "bar": 3},
  304. {"foo": 2, "bar": 1},
  305. {"foo": 2, "bar": 2},
  306. {"foo": 2, "bar": 3},
  307. ]
  308. )
  309. )
  310. # Test datanode indexing and slicing
  311. assert dn["foo"].equals(pd.Series([1, 1, 1, 2, 2, 2]))
  312. assert dn["bar"].equals(pd.Series([1, 2, 3, 1, 2, 3]))
  313. assert dn[:2].equals(pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}]))
  314. # Test filter data
  315. filtered_by_filter_method = dn.filter(("foo", 1, Operator.EQUAL))
  316. filtered_by_indexing = dn[dn["foo"] == 1]
  317. expected_data = pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}, {"foo": 1, "bar": 3}])
  318. assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
  319. assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)
  320. filtered_by_filter_method = dn.filter(("foo", 1, Operator.NOT_EQUAL))
  321. filtered_by_indexing = dn[dn["foo"] != 1]
  322. expected_data = pd.DataFrame([{"foo": 2, "bar": 1}, {"foo": 2, "bar": 2}, {"foo": 2, "bar": 3}])
  323. assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
  324. assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)
  325. filtered_by_filter_method = dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)
  326. filtered_by_indexing = dn[(dn["bar"] == 1) | (dn["bar"] == 2)]
  327. expected_data = pd.DataFrame(
  328. [
  329. {"foo": 1, "bar": 1},
  330. {"foo": 1, "bar": 2},
  331. {"foo": 2, "bar": 1},
  332. {"foo": 2, "bar": 2},
  333. ]
  334. )
  335. assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
  336. assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)
  337. def test_filter_numpy_exposed_type(self, tmp_sqlite_sqlite3_file_path):
  338. folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
  339. properties = {
  340. "db_engine": "sqlite",
  341. "read_query": "SELECT * FROM example",
  342. "write_query_builder": my_write_query_builder_with_pandas,
  343. "db_name": db_name,
  344. "sqlite_folder_path": folder_path,
  345. "sqlite_file_extension": file_extension,
  346. "exposed_type": "numpy",
  347. }
  348. dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
  349. _DataManagerFactory._build_manager()._repository._save(dn)
  350. dn.write(
  351. pd.DataFrame(
  352. [
  353. {"foo": 1, "bar": 1},
  354. {"foo": 1, "bar": 2},
  355. {"foo": 1, "bar": 3},
  356. {"foo": 2, "bar": 1},
  357. {"foo": 2, "bar": 2},
  358. {"foo": 2, "bar": 3},
  359. ]
  360. )
  361. )
  362. # Test datanode indexing and slicing
  363. assert np.array_equal(dn[0], np.array([1, 1]))
  364. assert np.array_equal(dn[1], np.array([1, 2]))
  365. assert np.array_equal(dn[:3], np.array([[1, 1], [1, 2], [1, 3]]))
  366. assert np.array_equal(dn[:, 0], np.array([1, 1, 1, 2, 2, 2]))
  367. assert np.array_equal(dn[1:4, :1], np.array([[1], [1], [2]]))
  368. # Test filter data
  369. assert np.array_equal(dn.filter(("foo", 1, Operator.EQUAL)), np.array([[1, 1], [1, 2], [1, 3]]))
  370. assert np.array_equal(dn[dn[:, 0] == 1], np.array([[1, 1], [1, 2], [1, 3]]))
  371. assert np.array_equal(dn.filter(("foo", 1, Operator.NOT_EQUAL)), np.array([[2, 1], [2, 2], [2, 3]]))
  372. assert np.array_equal(dn[dn[:, 0] != 1], np.array([[2, 1], [2, 2], [2, 3]]))
  373. assert np.array_equal(dn.filter(("bar", 2, Operator.EQUAL)), np.array([[1, 2], [2, 2]]))
  374. assert np.array_equal(dn[dn[:, 1] == 2], np.array([[1, 2], [2, 2]]))
  375. assert np.array_equal(
  376. dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR),
  377. np.array([[1, 1], [1, 2], [2, 1], [2, 2]]),
  378. )
  379. assert np.array_equal(dn[(dn[:, 1] == 1) | (dn[:, 1] == 2)], np.array([[1, 1], [1, 2], [2, 1], [2, 2]]))
  380. def test_filter_does_not_read_all_entities(self, tmp_sqlite_sqlite3_file_path):
  381. folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
  382. properties = {
  383. "db_engine": "sqlite",
  384. "read_query": "SELECT * FROM example",
  385. "write_query_builder": my_write_query_builder_with_pandas,
  386. "db_name": db_name,
  387. "sqlite_folder_path": folder_path,
  388. "sqlite_file_extension": file_extension,
  389. "exposed_type": "numpy",
  390. }
  391. dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
  392. # SQLDataNode.filter() should not call the MongoCollectionDataNode._read() method
  393. with patch.object(SQLDataNode, "_read") as read_mock:
  394. dn.filter(("foo", 1, Operator.EQUAL))
  395. dn.filter(("bar", 2, Operator.NOT_EQUAL))
  396. dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)
  397. assert read_mock["_read"].call_count == 0