mirror of
https://github.com/clinton-hall/nzbToMedia.git
synced 2024-11-14 09:30:21 -08:00
341 lines
12 KiB
Python
# coding=utf-8
|
|
|
|
from __future__ import (
|
|
absolute_import,
|
|
division,
|
|
print_function,
|
|
unicode_literals,
|
|
)
|
|
|
|
import os.path
|
|
import re
|
|
import sqlite3
|
|
import sys
|
|
import time
|
|
|
|
from six import text_type, PY2
|
|
|
|
import core
|
|
from core import logger
|
|
from core import permissions
|
|
|
|
if PY2:
    class Row(sqlite3.Row, object):
        """
        sqlite3.Row subclass that accepts text (unicode) column names on Python 2.

        Python 2's sqlite3.Row does not support unicode keys, so lookups are
        retried with the key encoded to a byte string first.
        """

        def __getitem__(self, item):
            """
            Fetch a value from the row.

            :param item: column name (text or bytes) or a numeric index.
            :return: the matching value from the underlying sqlite3.Row.
            """
            try:
                # Column names are byte strings in Python 2's sqlite3.Row.
                key = item.encode()
            except AttributeError:
                # Objects without encode() are treated as numeric indexes.
                key = item
            return super(Row, self).__getitem__(key)
else:
    from sqlite3 import Row
|
|
|
|
|
|
def db_filename(filename='nzbtomedia.db', suffix=None):
    """
    Return the location of a database file inside core.APP_ROOT.

    @param filename: The sqlite database filename to use. If not specified,
        nzbtomedia.db is used.
    @param suffix: Optional suffix appended to the filename. A '.' is added
        automatically, i.e. suffix='v0' turns nzbtomedia.db into
        nzbtomedia.db.v0
    @return: the full path of the database file.
    """
    if suffix:
        filename = '{0}.{1}'.format(filename, suffix)
    # Use this module's own os import rather than reaching through core.os,
    # which only works if the core package happens to import os itself.
    return os.path.join(core.APP_ROOT, filename)
|
|
|
|
|
|
class DBConnection(object):
    """
    Wrapper around a sqlite3 connection to a database located by db_filename().

    Every query is logged at logger.DB level. Queries failing with an
    sqlite3.OperationalError whose message contains 'unable to open database
    file' or 'database is locked' are retried up to _MAX_ATTEMPTS times;
    any other sqlite3 error is logged and re-raised.
    """

    # Maximum number of attempts for a query that keeps failing transiently.
    _MAX_ATTEMPTS = 5
    # Seconds to sleep between two attempts.
    _RETRY_DELAY = 1
    # Substrings of OperationalError messages that are worth retrying.
    _TRANSIENT_ERRORS = ('unable to open database file', 'database is locked')

    def __init__(self, filename='nzbtomedia.db', suffix=None, row_type=None):
        """
        Open the sqlite database.

        :param filename: database file name, resolved by db_filename().
        :param suffix: optional suffix appended to the file name (see db_filename()).
        :param row_type: accepted for backward compatibility; not used.

        If sqlite3.connect() raises OperationalError, the error is logged
        together with the mode/ownership of the database file, or of its
        directory when the file does not exist. In that case
        self.connection is left as None.
        """
        self.filename = filename
        self.connection = None
        path = db_filename(filename, suffix)
        try:
            self.connection = sqlite3.connect(path, 20)
        except sqlite3.OperationalError:
            if os.path.exists(path):
                logger.error('Please check permissions on database: {0}'.format(path))
            else:
                logger.error('Database file does not exist')
                # The file is missing, so report on the directory that should hold it.
                path = os.path.dirname(path)
                logger.error('Please check permissions on directory: {0}'.format(path))
            mode = permissions.mode(path)
            owner, group = permissions.ownership(path)
            logger.error(
                "=== PERMISSIONS ===========================\n"
                " Path : {0}\n"
                " Mode : {1}\n"
                " Owner: {2}\n"
                " Group: {3}\n"
                "===========================================".format(path, mode, owner, group),
            )
        else:
            # Row allows columns to be read by name, e.g. row['db_version'].
            self.connection.row_factory = Row

    @classmethod
    def _is_transient(cls, error):
        """Return True if an OperationalError's message marks it as worth retrying."""
        message = error.args[0] if error.args else ''
        return any(fragment in message for fragment in cls._TRANSIENT_ERRORS)

    def _run_with_retry(self, operation, rollback=False):
        """
        Call operation() and return its result, retrying transient errors.

        :param operation: zero-argument callable performing the database work.
        :param rollback: if True, roll back the connection after any sqlite3
            error (used for multi-statement transactions).
        :return: the result of operation(), or None if every attempt failed
            with a transient OperationalError.
        :raises sqlite3.DatabaseError: on any non-transient database error.
        """
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                return operation()
            except sqlite3.OperationalError as error:
                if rollback and self.connection:
                    self.connection.rollback()
                if not self._is_transient(error):
                    logger.log(u'DB error: {msg}'.format(msg=error), logger.ERROR)
                    raise
                logger.log(u'DB error: {msg}'.format(msg=error), logger.WARNING)
                # Sleeping after the final attempt would only delay the failure.
                if attempt < self._MAX_ATTEMPTS:
                    time.sleep(self._RETRY_DELAY)
            except sqlite3.DatabaseError as error:
                if rollback and self.connection:
                    self.connection.rollback()
                logger.log(u'Fatal error executing query: {msg}'.format(msg=error), logger.ERROR)
                raise
        logger.log(u'DB error: giving up after {n} attempts'.format(n=self._MAX_ATTEMPTS), logger.ERROR)
        return None

    def check_db_version(self):
        """
        Return the schema version stored in the db_version table.

        :return: int value of db_version in the first row, or 0 when the
            table is missing or empty.
        :raises sqlite3.OperationalError: for errors other than a missing
            db_version table.
        """
        try:
            result = self.select('SELECT db_version FROM db_version')
        except sqlite3.OperationalError as error:
            if 'no such table: db_version' in error.args[0]:
                return 0
            raise

        if result:
            return int(result[0]['db_version'])
        return 0

    def fetch(self, query, args=None):
        """
        Execute a query and return the first column of its first row.

        :param query: SQL statement to run; None is a no-op.
        :param args: optional bound parameters for the query.
        :return: the first column of the first row, or None when query is
            None, the query returned no rows, or every attempt failed with a
            transient error.
        :raises sqlite3.DatabaseError: on any non-transient database error.
        """
        if query is None:
            return None

        def run():
            if args is None:
                logger.log('{name}: {query}'.format(name=self.filename, query=query), logger.DB)
                cursor = self.connection.cursor()
                cursor.execute(query)
            else:
                logger.log('{name}: {query} with args {args}'.format(
                    name=self.filename, query=query, args=args), logger.DB)
                cursor = self.connection.cursor()
                cursor.execute(query, args)
            row = cursor.fetchone()
            # fetchone() returns None when the query produced no rows.
            return None if row is None else row[0]

        return self._run_with_retry(run)

    def mass_action(self, querylist, log_transaction=False):
        """
        Run several queries in a single transaction and commit once.

        :param querylist: sequence of query tuples, either (sql,) or
            (sql, args); empty tuples are skipped. None is a no-op.
        :param log_transaction: if True, log each query at DEBUG level.
        :return: list of cursors, one per executed query, or [] if every
            attempt failed with a transient error. On any sqlite3 error the
            transaction is rolled back before retrying or re-raising.
        :raises sqlite3.DatabaseError: on any non-transient database error.
        """
        if querylist is None:
            return

        def run():
            results = []
            for qu in querylist:
                if len(qu) == 1:
                    if log_transaction:
                        logger.log(qu[0], logger.DEBUG)
                    results.append(self.connection.execute(qu[0]))
                elif len(qu) > 1:
                    if log_transaction:
                        logger.log(u'{query} with args {args}'.format(query=qu[0], args=qu[1]), logger.DEBUG)
                    results.append(self.connection.execute(qu[0], qu[1]))
            self.connection.commit()
            logger.log(u'Transaction with {x} query\'s executed'.format(x=len(querylist)), logger.DEBUG)
            return results

        sql_result = self._run_with_retry(run, rollback=True)
        return [] if sql_result is None else sql_result

    def action(self, query, args=None):
        """
        Execute a single statement and commit.

        :param query: SQL statement to run; None is a no-op.
        :param args: optional bound parameters for the statement.
        :return: the sqlite3 cursor of the statement, or None when query is
            None or every attempt failed with a transient error.
        :raises sqlite3.DatabaseError: on any non-transient database error.
        """
        if query is None:
            return None

        def run():
            if args is None:
                logger.log(u'{name}: {query}'.format(name=self.filename, query=query), logger.DB)
                cursor = self.connection.execute(query)
            else:
                logger.log(u'{name}: {query} with args {args}'.format(
                    name=self.filename, query=query, args=args), logger.DB)
                cursor = self.connection.execute(query, args)
            self.connection.commit()
            return cursor

        return self._run_with_retry(run)

    def select(self, query, args=None):
        """
        Run a query through action() and return all result rows.

        :return: list of rows; [] when action() returned no cursor (query is
            None or all attempts failed transiently).
        """
        cursor = self.action(query, args)
        if cursor is None:
            return []
        return cursor.fetchall()

    def upsert(self, table_name, value_dict, key_dict):
        """
        Update the rows of table_name matching key_dict, or insert a new row.

        :param table_name: table to write to (interpolated unquoted into SQL).
        :param value_dict: column -> value pairs to set.
        :param key_dict: column -> value pairs identifying the row(s) to update.

        If the UPDATE changed no rows (judged by connection.total_changes),
        a row holding the columns of both value_dict and key_dict is inserted
        with INSERT OR IGNORE.
        """
        def gen_params(my_dict):
            return ['{key} = ?'.format(key=k) for k in my_dict]

        changes_before = self.connection.total_changes
        items = list(value_dict.values()) + list(key_dict.values())
        self.action(
            'UPDATE {table} '
            'SET {params} '
            'WHERE {conditions}'.format(
                table=table_name,
                params=', '.join(gen_params(value_dict)),
                conditions=' AND '.join(gen_params(key_dict)),
            ),
            items,
        )

        if self.connection.total_changes == changes_before:
            # Include the key columns, otherwise the inserted row would have
            # NULL keys and could never be matched by a later UPDATE.
            columns = list(value_dict.keys()) + list(key_dict.keys())
            self.action(
                'INSERT OR IGNORE INTO {table} ({columns}) '
                'VALUES ({values})'.format(
                    table=table_name,
                    columns=', '.join(map(text_type, columns)),
                    values=', '.join(['?'] * len(columns)),
                ),
                items,
            )

    def table_info(self, table_name):
        """
        Describe the columns of table_name.

        :return: dict mapping each column name to {'type': declared type}.
        """
        # PRAGMA arguments cannot be bound with '?', so quote the name as an
        # SQL identifier instead: wrap it in double quotes, doubling any
        # embedded double quote.
        quoted_name = '"{0}"'.format(text_type(table_name).replace('"', '""'))
        cursor = self.connection.execute('PRAGMA table_info({0})'.format(quoted_name))
        return {
            column['name']: {'type': column['type']}
            for column in cursor
        }
|
|
|
|
|
|
def sanity_check_database(connection, sanity_check):
    """
    Build a ``sanity_check`` instance for ``connection`` and run its check().

    :param connection: database connection handed to the checker.
    :param sanity_check: class (or factory) taking a connection and exposing check().
    """
    checker = sanity_check(connection)
    checker.check()
|
|
|
|
|
|
class DBSanityCheck(object):
    """Base class for database sanity checks; subclasses override check()."""

    def __init__(self, connection):
        # Connection the check operates on.
        self.connection = connection

    def check(self):
        """Perform the sanity check. The base implementation does nothing."""
        return None
|
|
|
|
|
|
# ===============
|
|
# = Upgrade API =
|
|
# ===============
|
|
|
|
def upgrade_database(connection, schema):
    """
    Apply the ``schema`` upgrade class, and its subclasses, to ``connection``.

    Any exception raised during the upgrade is logged and terminates the
    process with exit status 1.
    """
    logger.log(u'Checking database structure...', logger.MESSAGE)
    try:
        _process_upgrade(connection, schema)
    except Exception as upgrade_error:
        # Any upgrade failure is treated as fatal.
        logger.error(upgrade_error)
        sys.exit(1)
|
|
|
|
|
|
def pretty_name(class_name):
    """
    Split a CamelCase class name into space-separated words.

    Runs of capitals are kept as a single word (acronyms), e.g.
    'AddIDColumn' -> 'Add ID Column' and 'InitialSchema' -> 'Initial Schema'.
    Characters not covered by either word form (such as a lowercase prefix)
    are dropped.
    """
    # First alternative: an uppercase run not followed by a lowercase letter
    # or digit (an acronym). Second: a capital followed by lowercase/digits.
    return ' '.join(re.findall(r'[A-Z]+(?![a-z0-9])|[A-Z][a-z0-9]+', class_name))
|
|
|
|
|
|
def _process_upgrade(connection, upgrade_class):
    """
    Run ``upgrade_class`` against ``connection`` if needed, then its subclasses.

    The upgrade's execute() is called only when its test() returns a falsy
    value. Direct subclasses are then processed recursively, depth-first,
    after their parent.

    :raises sqlite3.DatabaseError: if execute() fails; the error is logged first.
    """
    instance = upgrade_class(connection)
    class_name = upgrade_class.__name__
    display_name = pretty_name(class_name)

    logger.log(u'Checking {name} database upgrade'.format(name=display_name), logger.DEBUG)
    if not instance.test():
        logger.log(u'Database upgrade required: {name}'.format(name=display_name), logger.MESSAGE)
        try:
            instance.execute()
        except sqlite3.DatabaseError as error:
            # Report through the logger (not print()) so the failing upgrade
            # is recorded alongside the other upgrade messages.
            logger.log(u'Error in {name}: {msg}'.format(name=class_name, msg=error), logger.ERROR)
            raise
        logger.log(u'{name} upgrade completed'.format(name=class_name), logger.DEBUG)
    else:
        logger.log(u'{name} upgrade not required'.format(name=class_name), logger.DEBUG)

    for upgrade_subclass in upgrade_class.__subclasses__():
        _process_upgrade(connection, upgrade_subclass)
|
|
|
|
|
|
# Base migration class. All future DB changes should be subclassed from this class
class SchemaUpgrade(object):
    """
    Base class for schema migrations.

    Provides helpers to inspect and alter the schema through ``connection``,
    which is expected to expose action(), select() and table_info()
    (as DBConnection does).
    """

    def __init__(self, connection):
        # Connection used by every helper below.
        self.connection = connection

    def has_table(self, table_name):
        """Return True if a table (not an index, view or trigger) named table_name exists."""
        rows = self.connection.select(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ? LIMIT 1;",
            (table_name,),
        )
        return bool(rows)

    def has_column(self, table_name, column):
        """Return True if table_name has a column named column."""
        return column in self.connection.table_info(table_name)

    def add_column(self, table, column, data_type='NUMERIC', default=0):
        """
        Add column to table and set it to default on every existing row.

        table, column and data_type are interpolated into the SQL unquoted,
        so they must come from trusted code, never from user input.
        """
        self.connection.action('ALTER TABLE {0} ADD {1} {2}'.format(table, column, data_type))
        self.connection.action('UPDATE {0} SET {1} = ?'.format(table, column), (default,))

    def check_db_version(self):
        """Return db_version from the last row of the db_version table, or 0 if it is empty."""
        result = self.connection.select('SELECT db_version FROM db_version')
        if result:
            return int(result[-1]['db_version'])
        return 0

    def inc_db_version(self):
        """Set db_version to the current version + 1 and return the new version."""
        new_version = self.check_db_version() + 1
        self.connection.action('UPDATE db_version SET db_version = ?', [new_version])
        return new_version
|