# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
from importlib import util
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from taipy.config.common.scope import Scope
from taipy.core.data.data_node_id import DataNodeId
from taipy.core.data.operator import JoinOperator, Operator
from taipy.core.data.sql import SQLDataNode
from taipy.core.exceptions.exceptions import MissingAppendQueryBuilder, MissingRequiredProperty


class MyCustomObject:
    def __init__(self, foo=None, bar=None, *args, **kwargs):
        self.foo = foo
        self.bar = bar
        self.args = args
        self.kwargs = kwargs
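

# Note: as the builders below and test_write_query_builder illustrate, a
# write_query_builder takes the data being written and returns the statements
# to execute in order; each entry is either a raw SQL string or a
# (query, parameters) tuple.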
def my_write_query_builder_with_pandas(data: pd.DataFrame):
    insert_data = data.to_dict("records")
    return ["DELETE FROM example", ("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def my_append_query_builder_with_pandas(data: pd.DataFrame):
    insert_data = data.to_dict("records")
    return [("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def single_write_query_builder(data):
    return "DELETE FROM example"


class TestSQLDataNode:
    __pandas_properties = [
        {
            "db_name": "taipy.sqlite3",
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_extra_args": {
                "TrustServerCertificate": "yes",
                "other": "value",
            },
        },
    ]
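
    # The configurations below are appended only when the corresponding driver
    # is importable (pyodbc for mssql, pymysql for mysql, psycopg2 for
    # postgresql), so the parametrized tests only run against available engines.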
    if util.find_spec("pyodbc"):
        __pandas_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mssql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )

    if util.find_spec("pymysql"):
        __pandas_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mysql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )

    if util.find_spec("psycopg2"):
        __pandas_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "postgresql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )
    @pytest.mark.parametrize("pandas_properties", __pandas_properties)
    def test_create(self, pandas_properties):
        dn = SQLDataNode(
            "foo_bar",
            Scope.SCENARIO,
            properties=pandas_properties,
        )
        assert isinstance(dn, SQLDataNode)
        assert dn.storage_type() == "sql"
        assert dn.config_id == "foo_bar"
        assert dn.scope == Scope.SCENARIO
        assert dn.id is not None
        assert dn.owner_id is None
        assert dn.job_ids == []
        assert dn.is_ready_for_reading
        assert dn.exposed_type == "pandas"
        assert dn.read_query == "SELECT * FROM example"
        assert dn.write_query_builder == my_write_query_builder_with_pandas
    @pytest.mark.parametrize("properties", __pandas_properties)
    def test_get_user_properties(self, properties):
        custom_properties = properties.copy()
        custom_properties["foo"] = "bar"
        dn = SQLDataNode(
            "foo_bar",
            Scope.SCENARIO,
            properties=custom_properties,
        )
        assert dn._get_user_properties() == {"foo": "bar"}
    @pytest.mark.parametrize(
        "properties",
        [
            {},
            {"db_username": "foo"},
            {"db_username": "foo", "db_password": "foo"},
            {"db_username": "foo", "db_password": "foo", "db_name": "foo"},
            {"engine": "sqlite"},
            {"engine": "mssql", "db_name": "foo"},
            {"engine": "mysql", "db_username": "foo"},
            {"engine": "postgresql", "db_username": "foo", "db_password": "foo"},
        ],
    )
    def test_create_with_missing_parameters(self, properties):
        with pytest.raises(MissingRequiredProperty):
            SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"))
        with pytest.raises(MissingRequiredProperty):
            SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
    @pytest.mark.parametrize("pandas_properties", __pandas_properties)
    def test_write_query_builder(self, pandas_properties):
        custom_properties = pandas_properties.copy()
        custom_properties.pop("db_extra_args")
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # Write through a mocked connection; the leading mock_calls entries
            # cover the connection/transaction setup, so the two executed
            # statements appear at indices 4 and 5.
            dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
            assert len(engine_mock.mock_calls[5].args) == 2
            assert engine_mock.mock_calls[5].args[0].text == "INSERT INTO example VALUES (:foo, :bar)"
            assert engine_mock.mock_calls[5].args[1] == [
                {"foo": 1, "bar": 4},
                {"foo": 2, "bar": 5},
                {"foo": 3, "bar": 6},
            ]

        custom_properties["write_query_builder"] = single_write_query_builder
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # A builder may also return a single raw query string.
            dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
    @pytest.mark.parametrize(
        "tmp_sqlite_path",
        [
            "tmp_sqlite_db_file_path",
            "tmp_sqlite_sqlite3_file_path",
        ],
    )
    def test_sqlite_read_file_with_different_extension(self, tmp_sqlite_path, request):
        tmp_sqlite_path = request.getfixturevalue(tmp_sqlite_path)
        folder_path, db_name, file_extension = tmp_sqlite_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * from example",
            "write_query_builder": single_write_query_builder,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
        data = dn.read()
        assert data.equals(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))
    def test_sqlite_append_pandas(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "append_query_builder": my_append_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
        original_data = pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}])
        data = dn.read()
        assert_frame_equal(data, original_data)

        append_data_1 = pd.DataFrame([{"foo": 5, "bar": 6}, {"foo": 7, "bar": 8}])
        dn.append(append_data_1)
        assert_frame_equal(dn.read(), pd.concat([original_data, append_data_1]).reset_index(drop=True))
    def test_sqlite_append_without_append_query_builder(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)
        with pytest.raises(MissingAppendQueryBuilder):
            dn.append(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))
    def test_filter_pandas_exposed_type(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "pandas",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
        dn.write(
            pd.DataFrame(
                [
                    {"foo": 1, "bar": 1},
                    {"foo": 1, "bar": 2},
                    {"foo": 1, "bar": 3},
                    {"foo": 2, "bar": 1},
                    {"foo": 2, "bar": 2},
                    {"foo": 2, "bar": 3},
                ]
            )
        )

        # Test datanode indexing and slicing
        assert dn["foo"].equals(pd.Series([1, 1, 1, 2, 2, 2]))
        assert dn["bar"].equals(pd.Series([1, 2, 3, 1, 2, 3]))
        assert dn[:2].equals(pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}]))

        # Test filter data
        filtered_by_filter_method = dn.filter(("foo", 1, Operator.EQUAL))
        filtered_by_indexing = dn[dn["foo"] == 1]
        expected_data = pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}, {"foo": 1, "bar": 3}])
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter(("foo", 1, Operator.NOT_EQUAL))
        filtered_by_indexing = dn[dn["foo"] != 1]
        expected_data = pd.DataFrame([{"foo": 2, "bar": 1}, {"foo": 2, "bar": 2}, {"foo": 2, "bar": 3}])
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)
        filtered_by_indexing = dn[(dn["bar"] == 1) | (dn["bar"] == 2)]
        expected_data = pd.DataFrame(
            [
                {"foo": 1, "bar": 1},
                {"foo": 1, "bar": 2},
                {"foo": 2, "bar": 1},
                {"foo": 2, "bar": 2},
            ]
        )
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)
    def test_filter_numpy_exposed_type(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "numpy",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
        dn.write(
            pd.DataFrame(
                [
                    {"foo": 1, "bar": 1},
                    {"foo": 1, "bar": 2},
                    {"foo": 1, "bar": 3},
                    {"foo": 2, "bar": 1},
                    {"foo": 2, "bar": 2},
                    {"foo": 2, "bar": 3},
                ]
            )
        )

        # Test datanode indexing and slicing
        assert np.array_equal(dn[0], np.array([1, 1]))
        assert np.array_equal(dn[1], np.array([1, 2]))
        assert np.array_equal(dn[:3], np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn[:, 0], np.array([1, 1, 1, 2, 2, 2]))
        assert np.array_equal(dn[1:4, :1], np.array([[1], [1], [2]]))

        # Test filter data
        assert np.array_equal(dn.filter(("foo", 1, Operator.EQUAL)), np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn[dn[:, 0] == 1], np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn.filter(("foo", 1, Operator.NOT_EQUAL)), np.array([[2, 1], [2, 2], [2, 3]]))
        assert np.array_equal(dn[dn[:, 0] != 1], np.array([[2, 1], [2, 2], [2, 3]]))
        assert np.array_equal(dn.filter(("bar", 2, Operator.EQUAL)), np.array([[1, 2], [2, 2]]))
        assert np.array_equal(dn[dn[:, 1] == 2], np.array([[1, 2], [2, 2]]))
        assert np.array_equal(
            dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR),
            np.array([[1, 1], [1, 2], [2, 1], [2, 2]]),
        )
        assert np.array_equal(dn[(dn[:, 1] == 1) | (dn[:, 1] == 2)], np.array([[1, 1], [1, 2], [2, 1], [2, 2]]))
    def test_filter_does_not_read_all_entities(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "numpy",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)

        # SQLDataNode.filter() should not call the SQLDataNode._read() method
        with patch.object(SQLDataNode, "_read") as read_mock:
            dn.filter(("foo", 1, Operator.EQUAL))
            dn.filter(("bar", 2, Operator.NOT_EQUAL))
            dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)

            # read_mock already patches _read; indexing it would only create a
            # fresh child mock whose call_count is always 0.
            assert read_mock.call_count == 0