_abstract_sql.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import os
  12. import re
  13. import urllib.parse
  14. from abc import abstractmethod
  15. from datetime import datetime, timedelta
  16. from typing import Dict, List, Optional, Set, Tuple, Union
  17. import numpy as np
  18. import pandas as pd
  19. from sqlalchemy import create_engine, text
  20. from taipy.config.common.scope import Scope
  21. from .._version._version_manager_factory import _VersionManagerFactory
  22. from ..data.operator import JoinOperator, Operator
  23. from ..exceptions.exceptions import MissingRequiredProperty, UnknownDatabaseEngine
  24. from ._tabular_datanode_mixin import _TabularDataNodeMixin
  25. from .data_node import DataNode
  26. from .data_node_id import DataNodeId, Edit
  27. class _AbstractSQLDataNode(DataNode, _TabularDataNodeMixin):
  28. """Abstract base class for data node implementations (SQLDataNode and SQLTableDataNode) that use SQL."""
  29. __STORAGE_TYPE = "NOT_IMPLEMENTED"
  30. __DB_NAME_KEY = "db_name"
  31. __DB_USERNAME_KEY = "db_username"
  32. __DB_PASSWORD_KEY = "db_password"
  33. __DB_HOST_KEY = "db_host"
  34. __DB_PORT_KEY = "db_port"
  35. __DB_ENGINE_KEY = "db_engine"
  36. __DB_DRIVER_KEY = "db_driver"
  37. __DB_EXTRA_ARGS_KEY = "db_extra_args"
  38. __SQLITE_FOLDER_PATH = "sqlite_folder_path"
  39. __SQLITE_FILE_EXTENSION = "sqlite_file_extension"
  40. __ENGINE_PROPERTIES: List[str] = [
  41. __DB_NAME_KEY,
  42. __DB_USERNAME_KEY,
  43. __DB_PASSWORD_KEY,
  44. __DB_HOST_KEY,
  45. __DB_PORT_KEY,
  46. __DB_DRIVER_KEY,
  47. __DB_EXTRA_ARGS_KEY,
  48. __SQLITE_FOLDER_PATH,
  49. __SQLITE_FILE_EXTENSION,
  50. ]
  51. __DB_HOST_DEFAULT = "localhost"
  52. __DB_PORT_DEFAULT = 1433
  53. __DB_DRIVER_DEFAULT = ""
  54. __SQLITE_FOLDER_PATH_DEFAULT = ""
  55. __SQLITE_FILE_EXTENSION_DEFAULT = ".db"
  56. __ENGINE_MSSQL = "mssql"
  57. __ENGINE_SQLITE = "sqlite"
  58. __ENGINE_MYSQL = "mysql"
  59. __ENGINE_POSTGRESQL = "postgresql"
  60. _ENGINE_REQUIRED_PROPERTIES: Dict[str, List[str]] = {
  61. __ENGINE_MSSQL: [__DB_USERNAME_KEY, __DB_PASSWORD_KEY, __DB_NAME_KEY],
  62. __ENGINE_MYSQL: [__DB_USERNAME_KEY, __DB_PASSWORD_KEY, __DB_NAME_KEY],
  63. __ENGINE_POSTGRESQL: [__DB_USERNAME_KEY, __DB_PASSWORD_KEY, __DB_NAME_KEY],
  64. __ENGINE_SQLITE: [__DB_NAME_KEY],
  65. }
  66. def __init__(
  67. self,
  68. config_id: str,
  69. scope: Scope,
  70. id: Optional[DataNodeId] = None,
  71. owner_id: Optional[str] = None,
  72. parent_ids: Optional[Set[str]] = None,
  73. last_edit_date: Optional[datetime] = None,
  74. edits: Optional[List[Edit]] = None,
  75. version: Optional[str] = None,
  76. validity_period: Optional[timedelta] = None,
  77. edit_in_progress: bool = False,
  78. editor_id: Optional[str] = None,
  79. editor_expiration_date: Optional[datetime] = None,
  80. properties: Optional[Dict] = None,
  81. ) -> None:
  82. if properties is None:
  83. properties = {}
  84. self._check_required_properties(properties)
  85. properties[self._EXPOSED_TYPE_PROPERTY] = _TabularDataNodeMixin._get_valid_exposed_type(properties)
  86. self._check_exposed_type(properties[self._EXPOSED_TYPE_PROPERTY])
  87. DataNode.__init__(
  88. self,
  89. config_id,
  90. scope,
  91. id,
  92. owner_id,
  93. parent_ids,
  94. last_edit_date,
  95. edits,
  96. version or _VersionManagerFactory._build_manager()._get_latest_version(),
  97. validity_period,
  98. edit_in_progress,
  99. editor_id,
  100. editor_expiration_date,
  101. **properties,
  102. )
  103. _TabularDataNodeMixin.__init__(self, **properties)
  104. self._engine = None
  105. if not self._last_edit_date: # type: ignore
  106. self._last_edit_date = datetime.now()
  107. self._TAIPY_PROPERTIES.update(
  108. {
  109. self.__DB_NAME_KEY,
  110. self.__DB_USERNAME_KEY,
  111. self.__DB_PASSWORD_KEY,
  112. self.__DB_HOST_KEY,
  113. self.__DB_PORT_KEY,
  114. self.__DB_ENGINE_KEY,
  115. self.__DB_DRIVER_KEY,
  116. self.__DB_EXTRA_ARGS_KEY,
  117. self.__SQLITE_FOLDER_PATH,
  118. self.__SQLITE_FILE_EXTENSION,
  119. self._EXPOSED_TYPE_PROPERTY,
  120. }
  121. )
  122. def _check_required_properties(self, properties: Dict):
  123. db_engine = properties.get(self.__DB_ENGINE_KEY)
  124. if not db_engine:
  125. raise MissingRequiredProperty(f"{self.__DB_ENGINE_KEY} is required.")
  126. if db_engine not in self._ENGINE_REQUIRED_PROPERTIES.keys():
  127. raise UnknownDatabaseEngine(f"Unknown engine: {db_engine}")
  128. required = self._ENGINE_REQUIRED_PROPERTIES[db_engine]
  129. if missing := set(required) - set(properties.keys()):
  130. raise MissingRequiredProperty(
  131. f"The following properties {', '.join(missing)} were not informed and are required."
  132. )
  133. def _get_engine(self):
  134. if self._engine is None:
  135. self._engine = create_engine(self._conn_string())
  136. return self._engine
  137. def _conn_string(self) -> str:
  138. properties = self.properties
  139. engine = properties.get(self.__DB_ENGINE_KEY)
  140. if self.__DB_USERNAME_KEY in self._ENGINE_REQUIRED_PROPERTIES[engine]:
  141. username = properties.get(self.__DB_USERNAME_KEY)
  142. username = urllib.parse.quote_plus(username)
  143. if self.__DB_PASSWORD_KEY in self._ENGINE_REQUIRED_PROPERTIES[engine]:
  144. password = properties.get(self.__DB_PASSWORD_KEY)
  145. password = urllib.parse.quote_plus(password)
  146. if self.__DB_NAME_KEY in self._ENGINE_REQUIRED_PROPERTIES[engine]:
  147. db_name = properties.get(self.__DB_NAME_KEY)
  148. db_name = urllib.parse.quote_plus(db_name)
  149. host = properties.get(self.__DB_HOST_KEY, self.__DB_HOST_DEFAULT)
  150. port = properties.get(self.__DB_PORT_KEY, self.__DB_PORT_DEFAULT)
  151. driver = properties.get(self.__DB_DRIVER_KEY, self.__DB_DRIVER_DEFAULT)
  152. extra_args = properties.get(self.__DB_EXTRA_ARGS_KEY, {})
  153. if driver:
  154. extra_args = {**extra_args, "driver": driver}
  155. for k, v in extra_args.items():
  156. extra_args[k] = re.sub(r"\s+", "+", v)
  157. extra_args_str = "&".join(f"{k}={str(v)}" for k, v in extra_args.items())
  158. if engine == self.__ENGINE_MSSQL:
  159. return f"mssql+pyodbc://{username}:{password}@{host}:{port}/{db_name}?{extra_args_str}"
  160. elif engine == self.__ENGINE_MYSQL:
  161. return f"mysql+pymysql://{username}:{password}@{host}:{port}/{db_name}?{extra_args_str}"
  162. elif engine == self.__ENGINE_POSTGRESQL:
  163. return f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{db_name}?{extra_args_str}"
  164. elif engine == self.__ENGINE_SQLITE:
  165. folder_path = properties.get(self.__SQLITE_FOLDER_PATH, self.__SQLITE_FOLDER_PATH_DEFAULT)
  166. file_extension = properties.get(self.__SQLITE_FILE_EXTENSION, self.__SQLITE_FILE_EXTENSION_DEFAULT)
  167. return "sqlite:///" + os.path.join(folder_path, f"{db_name}{file_extension}")
  168. raise UnknownDatabaseEngine(f"Unknown engine: {engine}")
  169. def filter(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
  170. properties = self.properties
  171. if properties[self._EXPOSED_TYPE_PROPERTY] == self._EXPOSED_TYPE_PANDAS:
  172. return self._read_as_pandas_dataframe(operators=operators, join_operator=join_operator)
  173. if properties[self._EXPOSED_TYPE_PROPERTY] == self._EXPOSED_TYPE_NUMPY:
  174. return self._read_as_numpy(operators=operators, join_operator=join_operator)
  175. return self._read_as(operators=operators, join_operator=join_operator)
  176. def _read(self):
  177. properties = self.properties
  178. if properties[self._EXPOSED_TYPE_PROPERTY] == self._EXPOSED_TYPE_PANDAS:
  179. return self._read_as_pandas_dataframe()
  180. if properties[self._EXPOSED_TYPE_PROPERTY] == self._EXPOSED_TYPE_NUMPY:
  181. return self._read_as_numpy()
  182. return self._read_as()
  183. def _read_as(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
  184. custom_class = self.properties[self._EXPOSED_TYPE_PROPERTY]
  185. with self._get_engine().connect() as connection:
  186. query_result = connection.execute(text(self._get_read_query(operators, join_operator)))
  187. return [custom_class(**row) for row in query_result]
  188. def _read_as_numpy(
  189. self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND
  190. ) -> np.ndarray:
  191. return self._read_as_pandas_dataframe(operators=operators, join_operator=join_operator).to_numpy()
  192. def _read_as_pandas_dataframe(
  193. self,
  194. columns: Optional[List[str]] = None,
  195. operators: Optional[Union[List, Tuple]] = None,
  196. join_operator=JoinOperator.AND,
  197. ):
  198. with self._get_engine().connect() as conn:
  199. result = conn.execute(text(self._get_read_query(operators, join_operator)))
  200. # On pandas 1.3.5 there's a bug that makes that the dataframe from sqlalchemy query is
  201. # created without headers
  202. keys = list(result.keys())
  203. if columns:
  204. return pd.DataFrame(result, columns=keys)[columns]
  205. return pd.DataFrame(result, columns=keys)
  206. @abstractmethod
  207. def _get_read_query(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
  208. query = self._get_base_read_query()
  209. if not operators:
  210. return query
  211. if not isinstance(operators, List):
  212. operators = [operators]
  213. conditions = []
  214. for key, value, operator in operators:
  215. if operator == Operator.EQUAL:
  216. conditions.append(f"{key} = '{value}'")
  217. elif operator == Operator.NOT_EQUAL:
  218. conditions.append(f"{key} <> '{value}'")
  219. elif operator == Operator.GREATER_THAN:
  220. conditions.append(f"{key} > '{value}'")
  221. elif operator == Operator.GREATER_OR_EQUAL:
  222. conditions.append(f"{key} >= '{value}'")
  223. elif operator == Operator.LESS_THAN:
  224. conditions.append(f"{key} < '{value}'")
  225. elif operator == Operator.LESS_OR_EQUAL:
  226. conditions.append(f"{key} <= '{value}'")
  227. if join_operator == JoinOperator.AND:
  228. query += f" WHERE {' AND '.join(conditions)}"
  229. elif join_operator == JoinOperator.OR:
  230. query += f" WHERE {' OR '.join(conditions)}"
  231. else:
  232. raise NotImplementedError(f"Join operator {join_operator} not implemented.")
  233. return query
  234. @abstractmethod
  235. def _get_base_read_query(self) -> str:
  236. raise NotImplementedError
  237. def _append(self, data) -> None:
  238. engine = self._get_engine()
  239. with engine.connect() as connection:
  240. with connection.begin() as transaction:
  241. try:
  242. self._do_append(data, engine, connection)
  243. except Exception as e:
  244. transaction.rollback()
  245. raise e
  246. else:
  247. transaction.commit()
  248. @abstractmethod
  249. def _do_append(self, data, engine, connection) -> None:
  250. raise NotImplementedError
  251. def _write(self, data) -> None:
  252. """Check data against a collection of types to handle insertion on the database."""
  253. engine = self._get_engine()
  254. with engine.connect() as connection:
  255. with connection.begin() as transaction:
  256. try:
  257. self._do_write(data, engine, connection)
  258. except Exception as e:
  259. transaction.rollback()
  260. raise e
  261. else:
  262. transaction.commit()
  263. @abstractmethod
  264. def _do_write(self, data, engine, connection) -> None:
  265. raise NotImplementedError
  266. def __setattr__(self, key: str, value) -> None:
  267. if key in self.__ENGINE_PROPERTIES:
  268. self._engine = None
  269. return super().__setattr__(key, value)