Browse Source

improve abstract sql coverage

Toan Quach 10 months ago
parent
commit
f0d820be2f

+ 0 - 1
taipy/core/data/_abstract_sql.py

@@ -190,7 +190,6 @@ class _AbstractSQLDataNode(DataNode, _TabularDataNodeMixin):
             folder_path = properties.get(self.__SQLITE_FOLDER_PATH, self.__SQLITE_FOLDER_PATH_DEFAULT)
             file_extension = properties.get(self.__SQLITE_FILE_EXTENSION, self.__SQLITE_FILE_EXTENSION_DEFAULT)
             return "sqlite:///" + os.path.join(folder_path, f"{db_name}{file_extension}")
-
         raise UnknownDatabaseEngine(f"Unknown engine: {engine}")
 
     def filter(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):

+ 98 - 0
tests/core/data/test_read_sql_table_data_node.py

@@ -17,6 +17,7 @@ import pandas as pd
 import pytest
 
 from taipy.config.common.scope import Scope
+from taipy.core.data.operator import JoinOperator, Operator
 from taipy.core.data.sql_table import SQLTableDataNode
 
 
@@ -105,6 +106,103 @@ class TestReadSQLTableDataNode:
             assert isinstance(pandas_data, pd.DataFrame)
             assert pandas_data.equals(pd.DataFrame(self.mock_read_value()))
 
+    def test_build_connection_string(self):
+        sql_properties = {
+            "db_username": "sa",
+            "db_password": "Passw0rd",
+            "db_name": "taipy",
+            "db_engine": "mssql",
+            "table_name": "example",
+            "db_driver": "default server",
+            "db_extra_args": {
+                "TrustServerCertificate": "yes",
+                "other": "value",
+            },
+        }
+        custom_properties = sql_properties.copy()
+        mssql_sql_data_node = SQLTableDataNode(
+            "foo",
+            Scope.SCENARIO,
+            properties=custom_properties,
+        )
+        assert (
+            mssql_sql_data_node._conn_string()
+            == "mssql+pyodbc://sa:Passw0rd@localhost:1433/taipy?TrustServerCertificate=yes&other=value&driver=default+server"
+        )
+
+        custom_properties["db_engine"] = "mysql"
+        mysql_sql_data_node = SQLTableDataNode(
+            "foo",
+            Scope.SCENARIO,
+            properties=custom_properties,
+        )
+        assert (
+            mysql_sql_data_node._conn_string()
+            == "mysql+pymysql://sa:Passw0rd@localhost:1433/taipy?TrustServerCertificate=yes&other=value&driver=default+server"
+        )
+
+        custom_properties["db_engine"] = "postgresql"
+        postgresql_sql_data_node = SQLTableDataNode(
+            "foo",
+            Scope.SCENARIO,
+            properties=custom_properties,
+        )
+        assert (
+            postgresql_sql_data_node._conn_string()
+            == "postgresql+psycopg2://sa:Passw0rd@localhost:1433/taipy?TrustServerCertificate=yes&other=value&driver=default+server"
+        )
+
+        custom_properties["db_engine"] = "sqlite"
+        sqlite_sql_data_node = SQLTableDataNode(
+            "foo",
+            Scope.SCENARIO,
+            properties=custom_properties,
+        )
+        assert sqlite_sql_data_node._conn_string() == "sqlite:///taipy.db"
+
+    @pytest.mark.parametrize("sql_properties", __sql_properties)
+    def test_get_read_query(self, sql_properties):
+        custom_properties = sql_properties.copy()
+
+        sql_data_node = SQLTableDataNode(
+            "foo",
+            Scope.SCENARIO,
+            properties=custom_properties,
+        )
+
+        assert sql_data_node._get_read_query(("key", 1, Operator.EQUAL)) == "SELECT * FROM example WHERE key = '1'"
+        assert sql_data_node._get_read_query(("key", 1, Operator.NOT_EQUAL)) == "SELECT * FROM example WHERE key <> '1'"
+        assert (
+            sql_data_node._get_read_query(("key", 1, Operator.GREATER_THAN)) == "SELECT * FROM example WHERE key > '1'"
+        )
+        assert (
+            sql_data_node._get_read_query(("key", 1, Operator.GREATER_OR_EQUAL))
+            == "SELECT * FROM example WHERE key >= '1'"
+        )
+        assert sql_data_node._get_read_query(("key", 1, Operator.LESS_THAN)) == "SELECT * FROM example WHERE key < '1'"
+        assert (
+            sql_data_node._get_read_query(("key", 1, Operator.LESS_OR_EQUAL))
+            == "SELECT * FROM example WHERE key <= '1'"
+        )
+
+        with pytest.raises(NotImplementedError):
+            sql_data_node._get_read_query(
+                [("key", 1, Operator.EQUAL), ("key2", 2, Operator.GREATER_THAN)], "SOME JoinOperator"
+            )
+
+        assert (
+            sql_data_node._get_read_query(
+                [("key", 1, Operator.EQUAL), ("key2", 2, Operator.GREATER_THAN)], JoinOperator.AND
+            )
+            == "SELECT * FROM example WHERE key = '1' AND key2 > '2'"
+        )
+        assert (
+            sql_data_node._get_read_query(
+                [("key", 1, Operator.EQUAL), ("key2", 2, Operator.GREATER_THAN)], JoinOperator.OR
+            )
+            == "SELECT * FROM example WHERE key = '1' OR key2 > '2'"
+        )
+
     @pytest.mark.parametrize("sql_properties", __sql_properties)
     def test_read_numpy(self, sql_properties):
         custom_properties = sql_properties.copy()

+ 40 - 10
tests/core/data/test_sql_data_node.py

@@ -23,7 +23,7 @@ from taipy.core.data._data_manager_factory import _DataManagerFactory
 from taipy.core.data.data_node_id import DataNodeId
 from taipy.core.data.operator import JoinOperator, Operator
 from taipy.core.data.sql import SQLDataNode
-from taipy.core.exceptions.exceptions import MissingAppendQueryBuilder, MissingRequiredProperty
+from taipy.core.exceptions.exceptions import MissingAppendQueryBuilder, MissingRequiredProperty, UnknownDatabaseEngine
 
 
 class MyCustomObject:
@@ -149,20 +149,50 @@ class TestSQLDataNode:
         "properties",
         [
             {},
-            {"db_username": "foo"},
-            {"db_username": "foo", "db_password": "foo"},
-            {"db_username": "foo", "db_password": "foo", "db_name": "foo"},
-            {"engine": "sqlite"},
-            {"engine": "mssql", "db_name": "foo"},
-            {"engine": "mysql", "db_username": "foo"},
-            {"engine": "postgresql", "db_username": "foo", "db_password": "foo"},
+            {"read_query": "ready query"},
+            {"read_query": "ready query", "write_query_builder": "write query"},
+            {"read_query": "ready query", "write_query_builder": "write query", "db_username": "foo"},
+            {
+                "read_query": "ready query",
+                "write_query_builder": "write query",
+                "db_username": "foo",
+                "db_password": "foo",
+            },
+            {
+                "read_query": "ready query",
+                "write_query_builder": "write query",
+                "db_username": "foo",
+                "db_password": "foo",
+                "db_name": "foo",
+            },
+            {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "some engine"},
+            {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "sqlite"},
+            {"read_query": "ready query", "write_query_builder": "write query", "db_engine": "mssql", "db_name": "foo"},
+            {
+                "read_query": "ready query",
+                "write_query_builder": "write query",
+                "db_engine": "mysql",
+                "db_username": "foo",
+            },
+            {
+                "read_query": "ready query",
+                "write_query_builder": "write query",
+                "db_engine": "postgresql",
+                "db_username": "foo",
+                "db_password": "foo",
+            },
         ],
     )
     def test_create_with_missing_parameters(self, properties):
         with pytest.raises(MissingRequiredProperty):
             SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"))
-        with pytest.raises(MissingRequiredProperty):
-            SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
+        engine = properties.get("db_engine")
+        if engine is not None and engine not in ["sqlite", "mssql", "mysql", "postgresql"]:
+            with pytest.raises(UnknownDatabaseEngine):
+                SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
+        else:
+            with pytest.raises(MissingRequiredProperty):
+                SQLDataNode("foo", Scope.SCENARIO, DataNodeId("dn_id"), properties=properties)
 
     @pytest.mark.parametrize("properties", __sql_properties)
     def test_write_query_builder(self, properties):