test_sql_data_node.py

# Copyright 2021-2025 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

from importlib import util
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from taipy import Scope
from taipy.common.config import Config
from taipy.core.data._data_manager_factory import _DataManagerFactory
from taipy.core.data.data_node_id import DataNodeId
from taipy.core.data.operator import JoinOperator, Operator
from taipy.core.data.sql import SQLDataNode
from taipy.core.exceptions.exceptions import MissingAppendQueryBuilder, MissingRequiredProperty, UnknownDatabaseEngine


class MyCustomObject:
    def __init__(self, foo=None, bar=None, *args, **kwargs):
        self.foo = foo
        self.bar = bar
        self.args = args
        self.kwargs = kwargs
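

# A write_query_builder (or append_query_builder) returns the statements that
# SQLDataNode.write() / append() will execute: either a single raw SQL string or a
# list whose entries are raw SQL strings or (SQL, parameters) tuples, as the
# builders below and the tests exercising them illustrate.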
def my_write_query_builder_with_pandas(data: pd.DataFrame):
    insert_data = data.to_dict("records")
    return ["DELETE FROM example", ("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def my_append_query_builder_with_pandas(data: pd.DataFrame):
    insert_data = data.to_dict("records")
    return [("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def single_write_query_builder(data):
    return "DELETE FROM example"


class TestSQLDataNode:
    __sql_properties = [
        {
            "db_name": "taipy.sqlite3",
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_extra_args": {
                "TrustServerCertificate": "yes",
                "other": "value",
            },
        },
    ]
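
    # The engine-specific configurations below are only exercised when the matching
    # driver is installed (pyodbc for mssql, pymysql for mysql, psycopg2 for postgresql).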
    if util.find_spec("pyodbc"):
        __sql_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mssql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )

    if util.find_spec("pymysql"):
        __sql_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mysql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )

    if util.find_spec("psycopg2"):
        __sql_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "postgresql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )

    @pytest.mark.parametrize("properties", __sql_properties)
    def test_create(self, properties):
        sql_dn_config = Config.configure_sql_data_node(id="foo_bar", **properties)
        dn = _DataManagerFactory._build_manager()._create_and_set(sql_dn_config, None, None)
        assert isinstance(dn, SQLDataNode)
        assert dn.storage_type() == "sql"
        assert dn.config_id == "foo_bar"
        assert dn.scope == Scope.SCENARIO
        assert dn.id is not None
        assert dn.owner_id is None
        assert dn.job_ids == []
        assert dn.is_ready_for_reading
        assert dn.properties["exposed_type"] == "pandas"
        assert dn.properties["read_query"] == "SELECT * FROM example"
        assert dn.properties["write_query_builder"] == my_write_query_builder_with_pandas

        sql_dn_config_1 = Config.configure_sql_data_node(
            id="foo",
            **properties,
            append_query_builder=my_append_query_builder_with_pandas,
            exposed_type=MyCustomObject,
        )
        dn_1 = _DataManagerFactory._build_manager()._create_and_set(sql_dn_config_1, None, None)
        assert isinstance(dn_1, SQLDataNode)
        assert dn_1.properties["exposed_type"] == MyCustomObject
        assert dn_1.properties["append_query_builder"] == my_append_query_builder_with_pandas

    @pytest.mark.parametrize("properties", __sql_properties)
    def test_get_user_properties(self, properties):
        custom_properties = properties.copy()
        custom_properties["foo"] = "bar"
        dn = SQLDataNode(
            "foo_bar",
            Scope.SCENARIO,
            properties=custom_properties,
        )
        assert dn._get_user_properties() == {"foo": "bar"}

    @pytest.mark.parametrize(
        "properties",
        [
            {},
            {"read_query": "ready query"},
            {"read_query": "ready query", "write_query_builder": "write query"},
            {"read_query": "ready query", "write_query_builder": "write query", "db_username": "foo"},
            {
                "read_query": "ready query",
                "write_query_builder": "write query",
                "db_username": "foo",
                "db_password": "foo",
            },
            {
                "read_query": "ready query",
                "write_query_builder": "write query",
                "db_username": "foo",
                "db_password": "foo",
                "db_name": "foo",
            },
            {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "some engine"},
            {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "sqlite"},
            {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "mssql", "db_name": "foo"},
            {
                "read_query": "ready query",
                "write_query_builder": "write query",
                "db_engine": "mysql",
                "db_username": "foo",
            },
            {
                "read_query": "ready query",
                "write_query_builder": "write query",
                "db_engine": "postgresql",
                "db_username": "foo",
                "db_password": "foo",
            },
        ],
    )
    def test_create_with_missing_parameters(self, properties):
        with pytest.raises(MissingRequiredProperty):
            SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"))
        engine = properties.get("db_engine")
        if engine is not None and engine not in ["sqlite", "mssql", "mysql", "postgresql"]:
            with pytest.raises(UnknownDatabaseEngine):
                SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
        else:
            with pytest.raises(MissingRequiredProperty):
                SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
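
    # In the test below, Engine.connect is patched so write() never touches a real
    # database; the statements produced by the query builder are inspected through
    # the recorded mock calls instead.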
    @pytest.mark.parametrize("properties", __sql_properties)
    def test_write_query_builder(self, properties):
        custom_properties = properties.copy()
        custom_properties.pop("db_extra_args")
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)

        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # Write through the mocked connection and check the two executed statements.
            dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
            assert len(engine_mock.mock_calls[5].args) == 2
            assert engine_mock.mock_calls[5].args[0].text == "INSERT INTO example VALUES (:foo, :bar)"
            assert engine_mock.mock_calls[5].args[1] == [
                {"foo": 1, "bar": 4},
                {"foo": 2, "bar": 5},
                {"foo": 3, "bar": 6},
            ]

        custom_properties["write_query_builder"] = single_write_query_builder
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)

        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # A write_query_builder may also return a single statement instead of a list.
            dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
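
    # The tmp_sqlite_* fixtures used below are assumed to come from the suite's
    # conftest.py and to yield a (folder_path, db_name, file_extension) tuple pointing
    # at a SQLite file whose "example" table is pre-populated with the rows
    # (foo=1, bar=2) and (foo=3, bar=4).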
    @pytest.mark.parametrize(
        "tmp_sqlite_path",
        [
            "tmp_sqlite_db_file_path",
            "tmp_sqlite_sqlite3_file_path",
        ],
    )
    def test_sqlite_read_file_with_different_extension(self, tmp_sqlite_path, request):
        tmp_sqlite_path = request.getfixturevalue(tmp_sqlite_path)
        folder_path, db_name, file_extension = tmp_sqlite_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * from example",
            "write_query_builder": single_write_query_builder,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
        data = dn.read()
        assert data.equals(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))

    def test_sqlite_append_pandas(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "append_query_builder": my_append_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
        original_data = pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}])
        data = dn.read()
        assert_frame_equal(data, original_data)

        append_data_1 = pd.DataFrame([{"foo": 5, "bar": 6}, {"foo": 7, "bar": 8}])
        dn.append(append_data_1)
        assert_frame_equal(dn.read(), pd.concat([original_data, append_data_1]).reset_index(drop=True))

    def test_sqlite_append_without_append_query_builder(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
        with pytest.raises(MissingAppendQueryBuilder):
            dn.append(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))
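
    # The two filter tests below check that dn.filter() and the equivalent
    # pandas/NumPy-style indexing on the data node produce the same results
    # for each exposed_type.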
    def test_filter_pandas_exposed_type(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "pandas",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
        dn.write(
            pd.DataFrame(
                [
                    {"foo": 1, "bar": 1},
                    {"foo": 1, "bar": 2},
                    {"foo": 1, "bar": 3},
                    {"foo": 2, "bar": 1},
                    {"foo": 2, "bar": 2},
                    {"foo": 2, "bar": 3},
                ]
            )
        )

        # Test data node indexing and slicing
        assert dn["foo"].equals(pd.Series([1, 1, 1, 2, 2, 2]))
        assert dn["bar"].equals(pd.Series([1, 2, 3, 1, 2, 3]))
        assert dn[:2].equals(pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}]))

        # Test filter data
        filtered_by_filter_method = dn.filter(("foo", 1, Operator.EQUAL))
        filtered_by_indexing = dn[dn["foo"] == 1]
        expected_data = pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}, {"foo": 1, "bar": 3}])
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter(("foo", 1, Operator.NOT_EQUAL))
        filtered_by_indexing = dn[dn["foo"] != 1]
        expected_data = pd.DataFrame([{"foo": 2, "bar": 1}, {"foo": 2, "bar": 2}, {"foo": 2, "bar": 3}])
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)
        filtered_by_indexing = dn[(dn["bar"] == 1) | (dn["bar"] == 2)]
        expected_data = pd.DataFrame(
            [
                {"foo": 1, "bar": 1},
                {"foo": 1, "bar": 2},
                {"foo": 2, "bar": 1},
                {"foo": 2, "bar": 2},
            ]
        )
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

    def test_filter_numpy_exposed_type(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "numpy",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
        dn.write(
            pd.DataFrame(
                [
                    {"foo": 1, "bar": 1},
                    {"foo": 1, "bar": 2},
                    {"foo": 1, "bar": 3},
                    {"foo": 2, "bar": 1},
                    {"foo": 2, "bar": 2},
                    {"foo": 2, "bar": 3},
                ]
            )
        )

        # Test data node indexing and slicing
        assert np.array_equal(dn[0], np.array([1, 1]))
        assert np.array_equal(dn[1], np.array([1, 2]))
        assert np.array_equal(dn[:3], np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn[:, 0], np.array([1, 1, 1, 2, 2, 2]))
        assert np.array_equal(dn[1:4, :1], np.array([[1], [1], [2]]))

        # Test filter data
        assert np.array_equal(dn.filter(("foo", 1, Operator.EQUAL)), np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn[dn[:, 0] == 1], np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn.filter(("foo", 1, Operator.NOT_EQUAL)), np.array([[2, 1], [2, 2], [2, 3]]))
        assert np.array_equal(dn[dn[:, 0] != 1], np.array([[2, 1], [2, 2], [2, 3]]))
        assert np.array_equal(dn.filter(("bar", 2, Operator.EQUAL)), np.array([[1, 2], [2, 2]]))
        assert np.array_equal(dn[dn[:, 1] == 2], np.array([[1, 2], [2, 2]]))
        assert np.array_equal(
            dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR),
            np.array([[1, 1], [1, 2], [2, 1], [2, 2]]),
        )
        assert np.array_equal(dn[(dn[:, 1] == 1) | (dn[:, 1] == 2)], np.array([[1, 1], [1, 2], [2, 1], [2, 2]]))

    def test_filter_does_not_read_all_entities(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "numpy",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)

        # SQLDataNode.filter() should not call the SQLDataNode._read() method
        with patch.object(SQLDataNode, "_read") as read_mock:
            dn.filter(("foo", 1, Operator.EQUAL))
            dn.filter(("bar", 2, Operator.NOT_EQUAL))
            dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)

            assert read_mock.call_count == 0