from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, Column, Text
from sqlalchemy.pool import QueuePool, SingletonThreadPool
+from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint,
+ CreateIndex)
from sqlalchemy.sql import select, and_
import ConfigParser
import os
import logging
-CURRENT_SCHEMA_VERSION = 1
-OPTIONS_COLUMNS = ['name', 'option', 'value']
-UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
+# Schema version 2 adds primary keys and indexes to all tables.
+CURRENT_SCHEMA_VERSION = 2
+# Table definitions: 'columns' is the list of (Text) column names,
+# 'primary_key' a tuple of column names forming the key, and 'indexes'
+# a list of column-name tuples, one tuple per index.
+OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
+                 'primary_key': ('name', 'option'),
+                 'indexes': [('name',)]
+                 }
+UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
+                     'primary_key': ('uuid', 'name'),
+                     'indexes': [('uuid',)]
+                     }
class DatabaseError(Exception):
pass
-class SqlStore(Log):
+class BaseStore(Log):
+    # Common base class for storage back-ends; declares the hooks the
+    # schema-upgrade code relies on.
+
+    # Some helper functions used for upgrades
+    def add_constraint(self, constraint):
+        # Add *constraint* (e.g. a PrimaryKeyConstraint) to a
+        # pre-existing table; concrete stores must override this.
+        raise NotImplementedError()
+
+    def add_index(self, index):
+        # Add *index* (an Index object) to a pre-existing table;
+        # concrete stores must override this.
+        raise NotImplementedError()
+
+
+class SqlStore(BaseStore):
__instances = {}
@classmethod
self._dbengine = create_engine(engine_name, **pool_args)
self.is_readonly = False
+    def add_constraint(self, constraint):
+        # Emit ALTER TABLE ... ADD CONSTRAINT for *constraint* on its
+        # (already created) table.
+        if self._dbengine.dialect.name != 'sqlite':
+            # It is impossible to add constraints to a pre-existing table for
+            # SQLite
+            # source: http://www.sqlite.org/omitted.html
+            create_constraint = AddConstraint(constraint, bind=self._dbengine)
+            create_constraint.execute()
+
+    def add_index(self, index):
+        # Emit CREATE INDEX for *index* on its (already created) table.
+        add_index = CreateIndex(index, bind=self._dbengine)
+        add_index.execute()
+
def debug(self, fact):
if self.db_conn_log:
super(SqlStore, self).debug(fact)
class SqlQuery(Log):
-    def __init__(self, db_obj, table, columns, trans=True):
+    def __init__(self, db_obj, table, table_def, trans=True):
self._db = db_obj
self._con = self._db.connection()
self._trans = self._con.begin() if trans else None
-    self._table = self._get_table(table, columns)
-
-    def _get_table(self, name, columns):
-    table = Table(name, MetaData(self._db.engine()))
-    for c in columns:
-        table.append_column(Column(c, Text()))
+        self._table = self._get_table(table, table_def)
+
+    def _get_table(self, name, table_def):
+        # Build the SQLAlchemy Table for *name*. table_def is either a
+        # legacy plain list of column names or a dict with 'columns',
+        # 'primary_key' and 'indexes' keys (see OPTIONS_TABLE).
+        if isinstance(table_def, list):
+            # Normalize the legacy form: no primary key, no indexes.
+            table_def = {'columns': table_def,
+                         'indexes': [],
+                         'primary_key': None}
+        table_creation = []
+        for col_name in table_def['columns']:
+            table_creation.append(Column(col_name, Text()))
+        if table_def['primary_key']:
+            table_creation.append(PrimaryKeyConstraint(
+                *table_def['primary_key']))
+        for index in table_def['indexes']:
+            # Deterministic index name so later upgrades can find it.
+            idx_name = 'idx_%s_%s' % (name, '_'.join(index))
+            table_creation.append(Index(idx_name, *index))
+        table = Table(name, MetaData(self._db.engine()), *table_creation)
return table
def _where(self, kvfilter):
self._con.execute(self._table.delete(self._where(kvfilter)))
-class FileStore(Log):
+class FileStore(BaseStore):
def __init__(self, name):
self._filename = name
self._config.read(self._filename)
return self._config
+    def add_constraint(self, constraint):
+        # Constraints cannot be enforced on an INI-file store.
+        raise NotImplementedError()
+
+    def add_index(self, index):
+        # Indexes are meaningless for an INI-file store.
+        raise NotImplementedError()
+
class FileQuery(Log):
-    def __init__(self, fstore, table, columns, trans=True):
+    def __init__(self, fstore, table, table_def, trans=True):
+        # We don't need indexes in a FileQuery, so drop that info
+        if isinstance(table_def, dict):
+            columns = table_def['columns']
+        else:
+            # Legacy callers may still pass a bare column-name list.
+            columns = table_def
self._fstore = fstore
self._config = fstore.get_config()
self._section = table
# We are storing multiple versions: one per class
# That way, we can support plugins with differing schema versions from
# the main codebase, and even in the same database.
- q = self._query(self._db, 'dbinfo', OPTIONS_COLUMNS, trans=False)
+ q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
q.create()
cls_name = self.__class__.__name__
current_version = self.load_options('dbinfo').get('%s_schema'
# themselves.
# They might implement downgrading if that's feasible, or just throw
# NotImplementedError
+ # Should return the new schema version
raise NotImplementedError()
def upgrade_database(self):
self._store_new_schema_version(self._code_schema_version())
elif old_schema_version != self._code_schema_version():
# Upgrade from old_schema_version to code_schema_version
-            self._upgrade_schema(old_schema_version)
-            self._store_new_schema_version(self._code_schema_version())
+            self.debug('Upgrading from schema version %i' % old_schema_version)
+            new_version = self._upgrade_schema(old_schema_version)
+            if not new_version:
+                # Adjacent string literals concatenate at compile time, so
+                # the %s is filled from the whole message.  (Joining them
+                # with '+' instead would make % bind only to the second
+                # literal, which has no placeholder, raising TypeError.)
+                error = ('Schema upgrade error: %s did not provide a '
+                         'new schema version number!' %
+                         self.__class__.__name__)
+                self.error(error)
+                raise Exception(error)
+            self._store_new_schema_version(new_version)
+            # Check if we are now up-to-date; recurses until the stored
+            # version catches up with the code version.
+            self.upgrade_database()
@property
def is_readonly(self):
kvfilter = dict()
if name:
kvfilter['name'] = name
- options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
+ options = self._load_data(table, OPTIONS_TABLE, kvfilter)
if name and name in options:
return options[name]
return options
curvals = dict()
q = None
try:
- q = self._query(self._db, table, OPTIONS_COLUMNS)
+ q = self._query(self._db, table, OPTIONS_TABLE)
rows = q.select({'name': name}, ['option', 'value'])
for row in rows:
curvals[row[0]] = row[1]
kvfilter = {'name': name}
q = None
try:
- q = self._query(self._db, table, OPTIONS_COLUMNS)
+ q = self._query(self._db, table, OPTIONS_TABLE)
if options is None:
q.delete(kvfilter)
else:
newid = str(uuid.uuid4())
q = None
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE)
for name in data:
q.insert((newid, name, data[name]))
q.commit()
kvfilter['name'] = name
if value:
kvfilter['value'] = value
- return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
+ return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
def save_unique_data(self, table, data):
q = None
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE)
for uid in data:
curvals = dict()
rows = q.select({'uuid': uid}, ['name', 'value'])
def del_unique_data(self, table, uuidval):
kvfilter = {'uuid': uuidval}
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
q.delete(kvfilter)
except Exception, e: # pylint: disable=broad-except
self.error("Failed to delete data from %s: [%s]" % (table, e))
def _reset_data(self, table):
q = None
try:
- q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+ q = self._query(self._db, table, UNIQUE_DATA_TABLE)
q.drop()
q.create()
q.commit()
'info_config',
'login_config',
'provider_config']:
- q = self._query(self._db, table, OPTIONS_COLUMNS, trans=False)
+ q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
q.create()
    def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        # Upgrade the admin/config tables from *old_version*; must return
+        # the new schema version on success.
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            for table in ['config',
+                          'info_config',
+                          'login_config',
+                          'provider_config']:
+                # pylint: disable=protected-access
+                table = self._query(self._db, table, OPTIONS_TABLE,
+                                    trans=False)._table
+                self._db.add_constraint(table.primary_key)
+                for index in table.indexes:
+                    self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()
+
+    def create_plugin_data_table(self, plugin_name):
+        # Create the per-plugin unique-data table named '<plugin>_data'.
+        table = plugin_name+'_data'
+        q = self._query(self._db, table, UNIQUE_DATA_TABLE,
+                        trans=False)
+        q.create()
class UserStore(Store):
return self.load_options(plugin+"_data", user)
def _initialize_schema(self):
- q = self._query(self._db, 'users', OPTIONS_COLUMNS, trans=False)
+ q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
q.create()
    def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        # Upgrade the 'users' table; returns the new schema version.
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            # pylint: disable=protected-access
+            table = self._query(self._db, 'users', OPTIONS_TABLE,
+                                trans=False)._table
+            self._db.add_constraint(table.primary_key)
+            for index in table.indexes:
+                self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()
class TranStore(Store):
super(TranStore, self).__init__('transactions.db')
def _initialize_schema(self):
- q = self._query(self._db, 'transactions', UNIQUE_DATA_COLUMNS,
+ q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
trans=False)
q.create()
    def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        # Upgrade the 'transactions' table; returns the new schema version.
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            # pylint: disable=protected-access
+            table = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
+                                trans=False)._table
+            self._db.add_constraint(table.primary_key)
+            for index in table.indexes:
+                self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()
class SAML2SessionStore(Store):
super(SAML2SessionStore, self).__init__(database_url=database_url)
self.table = 'saml2_sessions'
# pylint: disable=protected-access
- table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
table.create(checkfirst=True)
def _get_unique_id_from_column(self, name, value):
def remove_expired_sessions(self):
# pylint: disable=protected-access
- table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
sel = select([table.columns.uuid]). \
where(and_(table.c.name == 'expiration_time',
table.c.value <= datetime.datetime.now()))
self._reset_data(self.table)
def _initialize_schema(self):
- q = self._query(self._db, self.table, UNIQUE_DATA_COLUMNS,
+ q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
trans=False)
q.create()
    def _upgrade_schema(self, old_version):
-        raise NotImplementedError()
+        # Upgrade the saml2 sessions table; returns the new schema version.
+        if old_version == 1:
+            # In schema version 2, we added indexes and primary keys
+            # pylint: disable=protected-access
+            table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
+                                trans=False)._table
+            self._db.add_constraint(table.primary_key)
+            for index in table.indexes:
+                self._db.add_index(index)
+            return 2
+        else:
+            raise NotImplementedError()