Browse Source

feat(migrate_cli): ensure connection is closed

João André 11 months ago
parent
commit
918916ed96
2 changed files with 44 additions and 42 deletions
  1. 44 41
      taipy/core/_entity/_migrate/_migrate_sql.py
  2. 0 1
      tests/core/_entity/test_migrate_cli.py

+ 44 - 41
taipy/core/_entity/_migrate/_migrate_sql.py

@@ -9,6 +9,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+from contextlib import closing
 import json
 import os
 import shutil
@@ -24,32 +25,34 @@ __logger = _TaipyLogger._get_logger()
 
 def _load_all_entities_from_sql(db_file: str) -> Tuple[Dict, Dict]:
     conn = sqlite3.connect(db_file)
-    query = "SELECT model_id, document FROM taipy_model"
-    query_version = "SELECT * FROM taipy_version"
-    cursor = conn.execute(query)
-    entities = {}
-    versions = {}
-
-    for row in cursor:
-        _id = row[0]
-        document = row[1]
-        entities[_id] = {"data": json.loads(document)}
-
-    cursor = conn.execute(query_version)
-    for row in cursor:
-        id = row[0]
-        config_id = row[1]
-        creation_date = row[2]
-        is_production = row[3]
-        is_development = row[4]
-        is_latest = row[5]
-        versions[id] = {
-            "config_id": config_id,
-            "creation_date": creation_date,
-            "is_production": is_production,
-            "is_development": is_development,
-            "is_latest": is_latest,
-        }
+    with closing(conn):
+        query = "SELECT model_id, document FROM taipy_model"
+        query_version = "SELECT * FROM taipy_version"
+        cursor = conn.execute(query)
+        entities = {}
+        versions = {}
+
+        for row in cursor:
+            _id = row[0]
+            document = row[1]
+            entities[_id] = {"data": json.loads(document)}
+
+        cursor = conn.execute(query_version)
+        for row in cursor:
+            id = row[0]
+            config_id = row[1]
+            creation_date = row[2]
+            is_production = row[3]
+            is_development = row[4]
+            is_latest = row[5]
+            versions[id] = {
+                "config_id": config_id,
+                "creation_date": creation_date,
+                "is_production": is_production,
+                "is_development": is_development,
+                "is_latest": is_latest,
+            }
+
     return entities, versions
 
 
@@ -123,21 +126,21 @@ def __insert_version(version: dict, conn):
 
 def __write_entities_to_sql(_entities: Dict, _versions: Dict, db_file: str):
     conn = sqlite3.connect(db_file)
-
-    for k, entity in _entities.items():
-        if "SCENARIO" in k:
-            __insert_scenario(entity["data"], conn)
-        elif "TASK" in k:
-            __insert_task(entity["data"], conn)
-        elif "DATANODE" in k:
-            __insert_datanode(entity["data"], conn)
-        elif "JOB" in k:
-            __insert_job(entity["data"], conn)
-        elif "CYCLE" in k:
-            __insert_cycle(entity["data"], conn)
-
-    for _, version in _versions.items():
-        __insert_version(version, conn)
+    with closing(conn):
+        for k, entity in _entities.items():
+            if "SCENARIO" in k:
+                __insert_scenario(entity["data"], conn)
+            elif "TASK" in k:
+                __insert_task(entity["data"], conn)
+            elif "DATANODE" in k:
+                __insert_datanode(entity["data"], conn)
+            elif "JOB" in k:
+                __insert_job(entity["data"], conn)
+            elif "CYCLE" in k:
+                __insert_cycle(entity["data"], conn)
+
+        for _, version in _versions.items():
+            __insert_version(version, conn)
 
 
 def _restore_migrate_sql_entities(path: str) -> bool:

+ 0 - 1
tests/core/_entity/test_migrate_cli.py

@@ -207,7 +207,6 @@ def test_migrate_sql_backup_and_remove(caplog, tmp_sqlite):
     assert not os.path.exists(backup_sqlite)
 
 
-@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on windows due to PermissionError: [WinError 32]")
 def test_migrate_sql_backup_and_restore(caplog, tmp_sqlite):
     _MigrateCLI.create_parser()