# Mirror of https://gitlab.com/pamhyr/pamhyr2 (listing: 134 lines, 3.2 KiB, Python)
# SQL.py -- Pamhyr
# Copyright (C) 2023-2024 INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import os
|
|
import logging
|
|
import sqlite3
|
|
|
|
from pathlib import Path
|
|
|
|
from tools import timer
|
|
|
|
# Module logger: getLogger() without a name yields the root logger.
logger = logging.getLogger(name=None)
|
|
|
|
|
|
class SQL(object):
    """Thin wrapper around a SQLite database file.

    The ``_create``, ``_update``, ``_save`` and ``_load`` hooks only log a
    TODO warning here; presumably subclasses override them.
    """

    def _init_db_file(self, db):
        """Open the SQLite database at ``db``, creating it if needed.

        For a file that did not exist yet, ``_create`` then ``_save`` are
        called; otherwise ``_update`` then ``_load`` are called.

        :param db: path of the database file (str or path-like).
        """
        exists = Path(db).exists()

        # os.path.dirname() is "" for a bare file name (or ":memory:"),
        # and os.makedirs("") raises FileNotFoundError: only create the
        # parent directory when there is one.
        dirname = os.path.dirname(db)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        self._db = sqlite3.connect(db)
        self._cur = self._db.cursor()

        if not exists:
            self._create()      # Create db
            self._save()        # Save
        else:
            self._update()      # Update db scheme if necessary
            self._load()        # Load data

    def __init__(self, filename=None):
        """Initialize the object and open ``filename`` if one is given.

        :param filename: optional database file path; when None no
            connection is opened and ``_db``/``_cur`` stay None.
        """
        self._db = None
        self._cur = None

        if filename is not None:
            self._init_db_file(filename)

    def commit(self):
        """Commit the current transaction on the open connection."""
        logger.debug("SQL - commit")
        self._db.commit()

    def _close(self):
        """Commit pending changes, then close the connection."""
        self.commit()
        self._db.close()

    def _fetch_string(self, s):
        """Undo the quote escaping applied by ``_db_format``."""
        return s.replace("&#39;", "'")

    def _fetch_tuple(self, tup):
        """Return ``tup`` as a list, unescaping its string items."""
        return [
            self._fetch_string(v) if isinstance(v, str) else v
            for v in tup
        ]

    def _fetch_list(self, lst):
        """Return ``lst`` as a new list, unescaping strings and
        converting tuple items (rows) with ``_fetch_tuple``."""
        res = []
        for v in lst:
            if isinstance(v, str):
                v = self._fetch_string(v)
            elif isinstance(v, tuple):
                v = self._fetch_tuple(v)
            res.append(v)

        return res

    def _fetch(self, res, one):
        """Fetch the result of cursor ``res`` and unescape its strings.

        :param res: cursor returned by ``execute``.
        :param one: use ``fetchone`` when True, ``fetchall`` otherwise.
        :return: a row as a list, a list of rows (each a list), or
            None when ``fetchone`` finds no row.
        """
        value = res.fetchone() if one else res.fetchall()

        if isinstance(value, list):
            return self._fetch_list(value)
        if isinstance(value, tuple):
            return self._fetch_tuple(value)

        return value

    def _db_format(self, value):
        """Escape a value before it is embedded in a SQL string.

        Single quotes in strings are replaced by "&#39;" so that the value
        cannot close a '...'-quoted SQL literal (guard against SQL
        injection); ``_fetch_string`` reverses this on read. Non-string
        values are returned unchanged.

        NOTE(review): parameterized queries would be a sturdier defense
        than escaping, but ``execute`` only accepts a complete SQL string.
        """
        if isinstance(value, str):
            value = value.replace("'", "&#39;")
        return value

    @timer
    def execute(self, cmd, fetch_one=True, commit=False):
        """Execute the SQL statement ``cmd`` and return its result.

        :param cmd: complete SQL statement string.
        :param fetch_one: fetch a single row when True, all rows otherwise.
        :param commit: commit the transaction after executing ``cmd``.
        :return: the fetched (unescaped) result, see ``_fetch``; None if
            executing or fetching raised an exception, which is logged.
        """
        logger.debug(f"SQL - {cmd}")

        try:
            res = self._cur.execute(cmd)

            if commit:
                self._db.commit()

            return self._fetch(res, fetch_one)
        except Exception as e:
            # Best effort: log with traceback and report failure as None.
            # No `return` in `finally`, so BaseExceptions (e.g.
            # KeyboardInterrupt) are no longer silently swallowed.
            logger.exception(f"SQL - {e}")
            return None

    def _create(self):
        logger.warning("TODO: Create")

    def _update(self):
        logger.warning("TODO: Update")

    def _save(self):
        logger.warning("TODO: Save")

    def _load(self):
        logger.warning("TODO: LOAD")
|