# Copyright 2023 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

from importlib import util
from unittest.mock import patch

import modin.pandas as modin_pd
import numpy as np
import pandas as pd
import pytest
from modin.pandas.test.utils import df_equals
from pandas.testing import assert_frame_equal

from taipy.config.common.scope import Scope
from taipy.core.data.data_node_id import DataNodeId
from taipy.core.data.operator import JoinOperator, Operator
from taipy.core.data.sql import SQLDataNode
from taipy.core.exceptions.exceptions import MissingAppendQueryBuilder, MissingRequiredProperty


class MyCustomObject:
    def __init__(self, foo=None, bar=None, *args, **kwargs):
        self.foo = foo
        self.bar = bar
        self.args = args
        self.kwargs = kwargs
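

# A write_query_builder receives the data being written and returns the SQL to execute.
# As exercised by the builders below, each returned item may be a plain SQL string or a
# (statement, parameters) pair; the tests further down verify both forms.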
def my_write_query_builder_with_pandas(data: pd.DataFrame):
    insert_data = data.to_dict("records")
    return ["DELETE FROM example", ("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def my_write_query_builder_with_modin(data: modin_pd.DataFrame):
    insert_data = data.to_dict("records")
    return ["DELETE FROM example", ("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def my_append_query_builder_with_pandas(data: pd.DataFrame):
    insert_data = data.to_dict("records")
    return [("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def my_append_query_builder_with_modin(data: modin_pd.DataFrame):
    insert_data = data.to_dict("records")
    return [("INSERT INTO example VALUES (:foo, :bar)", insert_data)]


def single_write_query_builder(data):
    return "DELETE FROM example"
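

# The tmp_sqlite_* fixtures used below live in the suite's conftest.py and yield a
# (folder_path, db_name, file_extension) tuple pointing at a seeded SQLite database.
# A minimal sketch of what such a fixture could look like (the exact names and seed
# data are assumptions inferred from how the tests consume it):
#
#     @pytest.fixture
#     def tmp_sqlite_sqlite3_file_path(tmp_path):
#         import sqlite3
#         db_name, file_extension = "db", ".sqlite3"
#         conn = sqlite3.connect(tmp_path / f"{db_name}{file_extension}")
#         conn.execute("CREATE TABLE example (foo INTEGER, bar INTEGER)")
#         conn.executemany("INSERT INTO example VALUES (?, ?)", [(1, 2), (3, 4)])
#         conn.commit()
#         conn.close()
#         return tmp_path, db_name, file_extension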


class TestSQLDataNode:
    __pandas_properties = [
        {
            "db_name": "taipy.sqlite3",
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_extra_args": {
                "TrustServerCertificate": "yes",
                "other": "value",
            },
        },
    ]

    __modin_properties = [
        {
            "db_name": "taipy.sqlite3",
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_modin,
            "exposed_type": "modin",
            "db_extra_args": {
                "TrustServerCertificate": "yes",
                "other": "value",
            },
        },
    ]

    if util.find_spec("pyodbc"):
        __pandas_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mssql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )
        __modin_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mssql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_modin,
                "exposed_type": "modin",
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )

    if util.find_spec("pymysql"):
        __pandas_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mysql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )
        __modin_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "mysql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_modin,
                "exposed_type": "modin",
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )

    if util.find_spec("psycopg2"):
        __pandas_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "postgresql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_pandas,
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )
        __modin_properties.append(
            {
                "db_username": "sa",
                "db_password": "Passw0rd",
                "db_name": "taipy",
                "db_engine": "postgresql",
                "read_query": "SELECT * FROM example",
                "write_query_builder": my_write_query_builder_with_modin,
                "exposed_type": "modin",
                "db_extra_args": {
                    "TrustServerCertificate": "yes",
                },
            },
        )
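
    # The mssql/mysql/postgresql property sets above are only appended when the matching
    # driver (pyodbc, pymysql, psycopg2) is importable, so the parametrized tests below
    # only run against the engines available on the current machine.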
    @pytest.mark.parametrize("pandas_properties", __pandas_properties)
    @pytest.mark.parametrize("modin_properties", __modin_properties)
    def test_create(self, pandas_properties, modin_properties):
        dn = SQLDataNode(
            "foo_bar",
            Scope.SCENARIO,
            properties=pandas_properties,
        )
        assert isinstance(dn, SQLDataNode)
        assert dn.storage_type() == "sql"
        assert dn.config_id == "foo_bar"
        assert dn.scope == Scope.SCENARIO
        assert dn.id is not None
        assert dn.owner_id is None
        assert dn.job_ids == []
        assert dn.is_ready_for_reading
        assert dn.exposed_type == "pandas"
        assert dn.read_query == "SELECT * FROM example"
        assert dn.write_query_builder == my_write_query_builder_with_pandas

        dn = SQLDataNode(
            "foo_bar",
            Scope.SCENARIO,
            properties=modin_properties,
        )
        assert isinstance(dn, SQLDataNode)
        assert dn.storage_type() == "sql"
        assert dn.config_id == "foo_bar"
        assert dn.scope == Scope.SCENARIO
        assert dn.id is not None
        assert dn.owner_id is None
        assert dn.job_ids == []
        assert dn.is_ready_for_reading
        assert dn.exposed_type == "modin"
        assert dn.read_query == "SELECT * FROM example"
        assert dn.write_query_builder == my_write_query_builder_with_modin

    @pytest.mark.parametrize("properties", __pandas_properties + __modin_properties)
    def test_get_user_properties(self, properties):
        custom_properties = properties.copy()
        custom_properties["foo"] = "bar"
        dn = SQLDataNode(
            "foo_bar",
            Scope.SCENARIO,
            properties=custom_properties,
        )
        assert dn._get_user_properties() == {"foo": "bar"}
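
    # Each property dict below omits at least one connection property required for its
    # engine, so constructing the data node must raise MissingRequiredProperty.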
    @pytest.mark.parametrize(
        "properties",
        [
            {},
            {"db_username": "foo"},
            {"db_username": "foo", "db_password": "foo"},
            {"db_username": "foo", "db_password": "foo", "db_name": "foo"},
            {"engine": "sqlite"},
            {"engine": "mssql", "db_name": "foo"},
            {"engine": "mysql", "db_username": "foo"},
            {"engine": "postgresql", "db_username": "foo", "db_password": "foo"},
        ],
    )
    def test_create_with_missing_parameters(self, properties):
        with pytest.raises(MissingRequiredProperty):
            SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"))
        with pytest.raises(MissingRequiredProperty):
            SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)

    @pytest.mark.parametrize("pandas_properties", __pandas_properties)
    @pytest.mark.parametrize("modin_properties", __modin_properties)
    def test_write_query_builder(self, pandas_properties, modin_properties):
        custom_properties = pandas_properties.copy()
        custom_properties.pop("db_extra_args")
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # mock connection execute
            dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
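            # The first few entries in mock_calls are connect()/context-manager
            # bookkeeping recorded on the mock; indices 4 and 5 line up with the two
            # statements returned by the builder. These indices are an assumption about
            # SQLAlchemy's call sequence and would shift if that sequence changed.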
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
            assert len(engine_mock.mock_calls[5].args) == 2
            assert engine_mock.mock_calls[5].args[0].text == "INSERT INTO example VALUES (:foo, :bar)"
            assert engine_mock.mock_calls[5].args[1] == [
                {"foo": 1, "bar": 4},
                {"foo": 2, "bar": 5},
                {"foo": 3, "bar": 6},
            ]

        custom_properties["write_query_builder"] = single_write_query_builder
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # mock connection execute
            dn.write(pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"

        custom_properties = modin_properties.copy()
        custom_properties.pop("db_extra_args")
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # mock connection execute
            dn.write(modin_pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"
            assert len(engine_mock.mock_calls[5].args) == 2
            assert engine_mock.mock_calls[5].args[0].text == "INSERT INTO example VALUES (:foo, :bar)"
            assert engine_mock.mock_calls[5].args[1] == [
                {"foo": 1, "bar": 4},
                {"foo": 2, "bar": 5},
                {"foo": 3, "bar": 6},
            ]

        custom_properties["write_query_builder"] = single_write_query_builder
        dn = SQLDataNode("foo_bar", Scope.SCENARIO, properties=custom_properties)
        with patch("sqlalchemy.engine.Engine.connect") as engine_mock:
            # mock connection execute
            dn.write(modin_pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}))
            assert len(engine_mock.mock_calls[4].args) == 1
            assert engine_mock.mock_calls[4].args[0].text == "DELETE FROM example"

    @pytest.mark.parametrize(
        "tmp_sqlite_path",
        [
            "tmp_sqlite_db_file_path",
            "tmp_sqlite_sqlite3_file_path",
        ],
    )
    def test_sqlite_read_file_with_different_extension(self, tmp_sqlite_path, request):
        tmp_sqlite_path = request.getfixturevalue(tmp_sqlite_path)
        folder_path, db_name, file_extension = tmp_sqlite_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * from example",
            "write_query_builder": single_write_query_builder,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)

        data = dn.read()
        assert data.equals(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))
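
    # NOTE: with db_engine "sqlite", the data node resolves the database file from
    # sqlite_folder_path, db_name, and sqlite_file_extension, which is why the same
    # fixture data is readable whether the file ends in .db or .sqlite3.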

    def test_sqlite_append_pandas(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "append_query_builder": my_append_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)

        original_data = pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}])
        data = dn.read()
        assert_frame_equal(data, original_data)

        append_data_1 = pd.DataFrame([{"foo": 5, "bar": 6}, {"foo": 7, "bar": 8}])
        dn.append(append_data_1)
        assert_frame_equal(dn.read(), pd.concat([original_data, append_data_1]).reset_index(drop=True))

    @pytest.mark.modin
    def test_sqlite_append_modin(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "append_query_builder": my_append_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "modin",
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)

        original_data = modin_pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}])
        data = dn.read()
        df_equals(data, original_data)

        append_data_1 = modin_pd.DataFrame([{"foo": 5, "bar": 6}, {"foo": 7, "bar": 8}])
        dn.append(append_data_1)
        df_equals(dn.read(), modin_pd.concat([original_data, append_data_1]).reset_index(drop=True))
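
    # NOTE: test_sqlite_append_modin reuses the pandas query builders; presumably this
    # works because modin mirrors the pandas API (to_dict("records") in particular), so
    # one builder can serve both exposed types.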

    def test_sqlite_append_without_append_query_builder(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
        }
        dn = SQLDataNode("sqlite_dn", Scope.SCENARIO, properties=properties)

        with pytest.raises(MissingAppendQueryBuilder):
            dn.append(pd.DataFrame([{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]))

    def test_filter_pandas_exposed_type(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "pandas",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
        dn.write(
            pd.DataFrame(
                [
                    {"foo": 1, "bar": 1},
                    {"foo": 1, "bar": 2},
                    {"foo": 1, "bar": 3},
                    {"foo": 2, "bar": 1},
                    {"foo": 2, "bar": 2},
                    {"foo": 2, "bar": 3},
                ]
            )
        )

        # Test datanode indexing and slicing
        assert dn["foo"].equals(pd.Series([1, 1, 1, 2, 2, 2]))
        assert dn["bar"].equals(pd.Series([1, 2, 3, 1, 2, 3]))
        assert dn[:2].equals(pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}]))

        # Test filter data
        filtered_by_filter_method = dn.filter(("foo", 1, Operator.EQUAL))
        filtered_by_indexing = dn[dn["foo"] == 1]
        expected_data = pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}, {"foo": 1, "bar": 3}])
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter(("foo", 1, Operator.NOT_EQUAL))
        filtered_by_indexing = dn[dn["foo"] != 1]
        expected_data = pd.DataFrame([{"foo": 2, "bar": 1}, {"foo": 2, "bar": 2}, {"foo": 2, "bar": 3}])
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)
        filtered_by_indexing = dn[(dn["bar"] == 1) | (dn["bar"] == 2)]
        expected_data = pd.DataFrame(
            [
                {"foo": 1, "bar": 1},
                {"foo": 1, "bar": 2},
                {"foo": 2, "bar": 1},
                {"foo": 2, "bar": 2},
            ]
        )
        assert_frame_equal(filtered_by_filter_method.reset_index(drop=True), expected_data)
        assert_frame_equal(filtered_by_indexing.reset_index(drop=True), expected_data)

    @pytest.mark.modin
    def test_filter_modin_exposed_type(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_modin,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "modin",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
        dn.write(
            pd.DataFrame(
                [
                    {"foo": 1, "bar": 1},
                    {"foo": 1, "bar": 2},
                    {"foo": 1, "bar": 3},
                    {"foo": 2, "bar": 1},
                    {"foo": 2, "bar": 2},
                    {"foo": 2, "bar": 3},
                ]
            )
        )

        # Test datanode indexing and slicing
        assert dn["foo"].equals(pd.Series([1, 1, 1, 2, 2, 2]))
        assert dn["bar"].equals(pd.Series([1, 2, 3, 1, 2, 3]))
        assert dn[:2].equals(modin_pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}]))

        # Test filter data
        filtered_by_filter_method = dn.filter(("foo", 1, Operator.EQUAL))
        filtered_by_indexing = dn[dn["foo"] == 1]
        expected_data = modin_pd.DataFrame([{"foo": 1, "bar": 1}, {"foo": 1, "bar": 2}, {"foo": 1, "bar": 3}])
        df_equals(filtered_by_filter_method.reset_index(drop=True), expected_data)
        df_equals(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter(("foo", 1, Operator.NOT_EQUAL))
        filtered_by_indexing = dn[dn["foo"] != 1]
        expected_data = modin_pd.DataFrame([{"foo": 2, "bar": 1}, {"foo": 2, "bar": 2}, {"foo": 2, "bar": 3}])
        df_equals(filtered_by_filter_method.reset_index(drop=True), expected_data)
        df_equals(filtered_by_indexing.reset_index(drop=True), expected_data)

        filtered_by_filter_method = dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)
        filtered_by_indexing = dn[(dn["bar"] == 1) | (dn["bar"] == 2)]
        expected_data = modin_pd.DataFrame(
            [
                {"foo": 1, "bar": 1},
                {"foo": 1, "bar": 2},
                {"foo": 2, "bar": 1},
                {"foo": 2, "bar": 2},
            ]
        )
        df_equals(filtered_by_filter_method.reset_index(drop=True), expected_data)
        df_equals(filtered_by_indexing.reset_index(drop=True), expected_data)

    def test_filter_numpy_exposed_type(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "numpy",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)
        dn.write(
            pd.DataFrame(
                [
                    {"foo": 1, "bar": 1},
                    {"foo": 1, "bar": 2},
                    {"foo": 1, "bar": 3},
                    {"foo": 2, "bar": 1},
                    {"foo": 2, "bar": 2},
                    {"foo": 2, "bar": 3},
                ]
            )
        )
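
        # With exposed_type "numpy", read() yields a 2-D ndarray, so integer indexing
        # selects rows and dn[:, col] selects columns positionally.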
        # Test datanode indexing and slicing
        assert np.array_equal(dn[0], np.array([1, 1]))
        assert np.array_equal(dn[1], np.array([1, 2]))
        assert np.array_equal(dn[:3], np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn[:, 0], np.array([1, 1, 1, 2, 2, 2]))
        assert np.array_equal(dn[1:4, :1], np.array([[1], [1], [2]]))

        # Test filter data
        assert np.array_equal(dn.filter(("foo", 1, Operator.EQUAL)), np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn[dn[:, 0] == 1], np.array([[1, 1], [1, 2], [1, 3]]))
        assert np.array_equal(dn.filter(("foo", 1, Operator.NOT_EQUAL)), np.array([[2, 1], [2, 2], [2, 3]]))
        assert np.array_equal(dn[dn[:, 0] != 1], np.array([[2, 1], [2, 2], [2, 3]]))
        assert np.array_equal(dn.filter(("bar", 2, Operator.EQUAL)), np.array([[1, 2], [2, 2]]))
        assert np.array_equal(dn[dn[:, 1] == 2], np.array([[1, 2], [2, 2]]))
        assert np.array_equal(
            dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR),
            np.array([[1, 1], [1, 2], [2, 1], [2, 2]]),
        )
        assert np.array_equal(dn[(dn[:, 1] == 1) | (dn[:, 1] == 2)], np.array([[1, 1], [1, 2], [2, 1], [2, 2]]))
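
    # filter() is expected to push its predicates into the SQL it issues rather than
    # read the whole table into memory; the next test pins that down by asserting that
    # _read() is never called.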
    def test_filter_does_not_read_all_entities(self, tmp_sqlite_sqlite3_file_path):
        folder_path, db_name, file_extension = tmp_sqlite_sqlite3_file_path
        properties = {
            "db_engine": "sqlite",
            "read_query": "SELECT * FROM example",
            "write_query_builder": my_write_query_builder_with_pandas,
            "db_name": db_name,
            "sqlite_folder_path": folder_path,
            "sqlite_file_extension": file_extension,
            "exposed_type": "numpy",
        }
        dn = SQLDataNode("foo", Scope.SCENARIO, properties=properties)

        # SQLDataNode.filter() should not call the SQLDataNode._read() method
        with patch.object(SQLDataNode, "_read") as read_mock:
            dn.filter(("foo", 1, Operator.EQUAL))
            dn.filter(("bar", 2, Operator.NOT_EQUAL))
            dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)

            assert read_mock.call_count == 0