X-Git-Url: http://git.cascardo.eti.br/?a=blobdiff_plain;f=ipsilon%2Futil%2Fdata.py;h=8d2a1d5f047f8349ef8165616af0edcbbc7e0272;hb=2ff2f766737abf1615bca802677cb2386b32213d;hp=53a17563657e13704c17779f1d86e41260e0831c;hpb=8445b3297cd0b25989f2575c21bf3426aee7c5ad;p=cascardo%2Fipsilon.git diff --git a/ipsilon/util/data.py b/ipsilon/util/data.py index 53a1756..8d2a1d5 100644 --- a/ipsilon/util/data.py +++ b/ipsilon/util/data.py @@ -6,6 +6,8 @@ from ipsilon.util.log import Log from sqlalchemy import create_engine from sqlalchemy import MetaData, Table, Column, Text from sqlalchemy.pool import QueuePool, SingletonThreadPool +from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint, + CreateIndex) from sqlalchemy.sql import select, and_ import ConfigParser import os @@ -13,17 +15,36 @@ import uuid import logging -CURRENT_SCHEMA_VERSION = 1 -OPTIONS_COLUMNS = ['name', 'option', 'value'] -UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value'] +CURRENT_SCHEMA_VERSION = 2 +OPTIONS_TABLE = {'columns': ['name', 'option', 'value'], + 'primary_key': ('name', 'option'), + 'indexes': [('name',)] + } +UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'], + 'primary_key': ('uuid', 'name'), + 'indexes': [('uuid',)] + } -class SqlStore(Log): +class DatabaseError(Exception): + pass + + +class BaseStore(Log): + # Some helper functions used for upgrades + def add_constraint(self, table): + raise NotImplementedError() + + def add_index(self, index): + raise NotImplementedError() + + +class SqlStore(BaseStore): __instances = {} @classmethod def get_connection(cls, name): - if name not in cls.__instances.keys(): + if name not in cls.__instances: if cherrypy.config.get('db.conn.log', False): logging.debug('SqlStore new: %s', name) cls.__instances[name] = SqlStore(name) @@ -50,6 +71,18 @@ class SqlStore(Log): self._dbengine = create_engine(engine_name, **pool_args) self.is_readonly = False + def add_constraint(self, constraint): + if self._dbengine.dialect.name != 'sqlite': + # It is impossible to add constraints to a pre-existing table for + # SQLite + # source: http://www.sqlite.org/omitted.html + create_constraint = AddConstraint(constraint, bind=self._dbengine) + create_constraint.execute() + + def add_index(self, index): + add_index = CreateIndex(index, bind=self._dbengine) + add_index.execute() + def debug(self, fact): if self.db_conn_log: super(SqlStore, self).debug(fact) @@ -68,25 +101,29 @@ class SqlStore(Log): return conn -def SqlAutotable(f): - def at(self, *args, **kwargs): - self.create() - return f(self, *args, **kwargs) - return at - - class SqlQuery(Log): - def __init__(self, db_obj, table, columns, trans=True): + def __init__(self, db_obj, table, table_def, trans=True): self._db = db_obj self._con = self._db.connection() self._trans = self._con.begin() if trans else None - self._table = self._get_table(table, columns) - - def _get_table(self, name, columns): - table = Table(name, MetaData(self._db.engine())) - for c in columns: - table.append_column(Column(c, Text())) + self._table = self._get_table(table, table_def) + + def _get_table(self, name, table_def): + if isinstance(table_def, list): + table_def = {'columns': table_def, + 'indexes': [], + 'primary_key': None} + table_creation = [] + for col_name in table_def['columns']: + table_creation.append(Column(col_name, Text())) + if table_def['primary_key']: + table_creation.append(PrimaryKeyConstraint( + *table_def['primary_key'])) + for index in table_def['indexes']: + idx_name = 'idx_%s_%s' % (name, '_'.join(index)) + table_creation.append(Index(idx_name, *index)) + table = Table(name, MetaData(self._db.engine()), *table_creation) return table def _where(self, kvfilter): @@ -122,25 +159,21 @@ class SqlQuery(Log): def drop(self): self._table.drop(checkfirst=True) - @SqlAutotable def select(self, kvfilter=None, columns=None): return self._con.execute(select(self._columns(columns), self._where(kvfilter))) - @SqlAutotable def insert(self, values): self._con.execute(self._table.insert(values)) - @SqlAutotable def update(self, values, kvfilter): self._con.execute(self._table.update(self._where(kvfilter), values)) - @SqlAutotable def delete(self, kvfilter): self._con.execute(self._table.delete(self._where(kvfilter))) -class FileStore(Log): +class FileStore(BaseStore): def __init__(self, name): self._filename = name @@ -163,10 +196,21 @@ class FileStore(Log): self._config.read(self._filename) return self._config + def add_constraint(self, table): + raise NotImplementedError() + + def add_index(self, index): + raise NotImplementedError() + class FileQuery(Log): - def __init__(self, fstore, table, columns, trans=True): + def __init__(self, fstore, table, table_def, trans=True): + # We don't need indexes in a FileQuery, so drop that info + if isinstance(table_def, dict): + columns = table_def['columns'] + else: + columns = table_def self._fstore = fstore self._config = fstore.get_config() self._section = table @@ -253,6 +297,8 @@ class FileQuery(Log): class Store(Log): + _is_upgrade = False + def __init__(self, config_name=None, database_url=None): if config_name is None and database_url is None: raise ValueError('config_name or database_url must be provided') @@ -269,33 +315,102 @@ class Store(Log): else: self._db = SqlStore.get_connection(name) self._query = SqlQuery - self._upgrade_database() - def _upgrade_database(self): + if not self._is_upgrade: + self._check_database() + + def _code_schema_version(self): + # This function makes it possible for separate plugins to have + # different schema versions. We default to the global schema + # version. + return CURRENT_SCHEMA_VERSION + + def _get_schema_version(self): + # We are storing multiple versions: one per class + # That way, we can support plugins with differing schema versions from + # the main codebase, and even in the same database. + q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False) + q.create() + q._con.close() # pylint: disable=protected-access + cls_name = self.__class__.__name__ + current_version = self.load_options('dbinfo').get('%s_schema' + % cls_name, {}) + if 'version' in current_version: + return int(current_version['version']) + else: + # Also try the old table name. + # "scheme" was a typo, but we need to retain that now for compat + fallback_version = self.load_options('dbinfo').get('scheme', + {}) + if 'version' in fallback_version: + # Explanation for this is in def upgrade_database(self) + return -1 + else: + return None + + def _check_database(self): if self.is_readonly: # If the database is readonly, we cannot do anything to the # schema. Let's just return, and assume people checked the # upgrade notes return - current_version = self.load_options('dbinfo').get('scheme', None) - if current_version is None or 'version' not in current_version: - # No version stored, storing current version - self.save_options('dbinfo', 'scheme', - {'version': CURRENT_SCHEMA_VERSION}) - current_version = CURRENT_SCHEMA_VERSION - else: - current_version = int(current_version['version']) - if current_version != CURRENT_SCHEMA_VERSION: - self.debug('Upgrading database schema from %i to %i' % ( - current_version, CURRENT_SCHEMA_VERSION)) - self._upgrade_database_from(current_version) - - def _upgrade_database_from(self, old_schema_version): - # Insert code here to upgrade from old_schema_version to - # CURRENT_SCHEMA_VERSION - raise Exception('Unable to upgrade database to current schema' - ' version: version %i is unknown!' % - old_schema_version) + + current_version = self._get_schema_version() + if current_version is None: + self.error('Database initialization required! ' + + 'Please run ipsilon-upgrade-database') + raise DatabaseError('Database initialization required for %s' % + self.__class__.__name__) + if current_version != self._code_schema_version(): + self.error('Database upgrade required! ' + + 'Please run ipsilon-upgrade-database') + raise DatabaseError('Database upgrade required for %s' % + self.__class__.__name__) + + def _store_new_schema_version(self, new_version): + cls_name = self.__class__.__name__ + self.save_options('dbinfo', '%s_schema' % cls_name, + {'version': new_version}) + + def _initialize_schema(self): + raise NotImplementedError() + + def _upgrade_schema(self, old_version): + # Datastores need to figure out what to do with bigger old_versions + # themselves. + # They might implement downgrading if that's feasible, or just throw + # NotImplementedError + # Should return the new schema version + raise NotImplementedError() + + def upgrade_database(self): + # Do whatever is needed to get schema to current version + old_schema_version = self._get_schema_version() + if old_schema_version is None: + # Just initialize a new schema + self._initialize_schema() + self._store_new_schema_version(self._code_schema_version()) + elif old_schema_version == -1: + # This is a special-case from 1.0: we only created tables at the + # first time they were actually used, but the upgrade code assumes + # that the tables exist. So let's fix this. + self._initialize_schema() + # The old version was schema version 1 + self._store_new_schema_version(1) + self.upgrade_database() + elif old_schema_version != self._code_schema_version(): + # Upgrade from old_schema_version to code_schema_version + self.debug('Upgrading from schema version %i' % old_schema_version) + new_version = self._upgrade_schema(old_schema_version) + if not new_version: + error = ('Schema upgrade error: %s did not provide a ' + + 'new schema version number!' % + self.__class__.__name__) + self.error(error) + raise Exception(error) + self._store_new_schema_version(new_version) + # Check if we are now up-to-date + self.upgrade_database() @property def is_readonly(self): @@ -343,7 +458,7 @@ class Store(Log): kvfilter = dict() if name: kvfilter['name'] = name - options = self._load_data(table, OPTIONS_COLUMNS, kvfilter) + options = self._load_data(table, OPTIONS_TABLE, kvfilter) if name and name in options: return options[name] return options @@ -352,7 +467,7 @@ class Store(Log): curvals = dict() q = None try: - q = self._query(self._db, table, OPTIONS_COLUMNS) + q = self._query(self._db, table, OPTIONS_TABLE) rows = q.select({'name': name}, ['option', 'value']) for row in rows: curvals[row[0]] = row[1] @@ -375,7 +490,7 @@ class Store(Log): kvfilter = {'name': name} q = None try: - q = self._query(self._db, table, OPTIONS_COLUMNS) + q = self._query(self._db, table, OPTIONS_TABLE) if options is None: q.delete(kvfilter) else: @@ -393,7 +508,7 @@ class Store(Log): newid = str(uuid.uuid4()) q = None try: - q = self._query(self._db, table, UNIQUE_DATA_COLUMNS) + q = self._query(self._db, table, UNIQUE_DATA_TABLE) for name in data: q.insert((newid, name, data[name])) q.commit() @@ -412,12 +527,12 @@ class Store(Log): kvfilter['name'] = name if value: kvfilter['value'] = value - return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter) + return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter) def save_unique_data(self, table, data): q = None try: - q = self._query(self._db, table, UNIQUE_DATA_COLUMNS) + q = self._query(self._db, table, UNIQUE_DATA_TABLE) for uid in data: curvals = dict() rows = q.select({'uuid': uid}, ['name', 'value']) @@ -446,7 +561,7 @@ class Store(Log): def del_unique_data(self, table, uuidval): kvfilter = {'uuid': uuidval} try: - q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False) + q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False) q.delete(kvfilter) except Exception, e: # pylint: disable=broad-except self.error("Failed to delete data from %s: [%s]" % (table, e)) @@ -454,7 +569,7 @@ class Store(Log): def _reset_data(self, table): q = None try: - q = self._query(self._db, table, UNIQUE_DATA_COLUMNS) + q = self._query(self._db, table, UNIQUE_DATA_TABLE) q.drop() q.create() q.commit() @@ -487,6 +602,40 @@ class AdminStore(Store): table = plugin+"_data" self._reset_data(table) + def _initialize_schema(self): + for table in ['config', + 'info_config', + 'login_config', + 'provider_config']: + q = self._query(self._db, table, OPTIONS_TABLE, trans=False) + q.create() + q._con.close() # pylint: disable=protected-access + + def _upgrade_schema(self, old_version): + if old_version == 1: + # In schema version 2, we added indexes and primary keys + for table in ['config', + 'info_config', + 'login_config', + 'provider_config']: + # pylint: disable=protected-access + table = self._query(self._db, table, OPTIONS_TABLE, + trans=False)._table + self._db.add_constraint(table.primary_key) + for index in table.indexes: + self._db.add_index(index) + return 2 + else: + raise NotImplementedError() + + def create_plugin_data_table(self, plugin_name): + if not self.is_readonly: + table = plugin_name+'_data' + q = self._query(self._db, table, UNIQUE_DATA_TABLE, + trans=False) + q.create() + q._con.close() # pylint: disable=protected-access + class UserStore(Store): @@ -505,20 +654,57 @@ class UserStore(Store): def load_plugin_data(self, plugin, user): return self.load_options(plugin+"_data", user) + def _initialize_schema(self): + q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False) + q.create() + q._con.close() # pylint: disable=protected-access + + def _upgrade_schema(self, old_version): + if old_version == 1: + # In schema version 2, we added indexes and primary keys + # pylint: disable=protected-access + table = self._query(self._db, 'users', OPTIONS_TABLE, + trans=False)._table + self._db.add_constraint(table.primary_key) + for index in table.indexes: + self._db.add_index(index) + return 2 + else: + raise NotImplementedError() + class TranStore(Store): def __init__(self, path=None): super(TranStore, self).__init__('transactions.db') + def _initialize_schema(self): + q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE, + trans=False) + q.create() + q._con.close() # pylint: disable=protected-access + + def _upgrade_schema(self, old_version): + if old_version == 1: + # In schema version 2, we added indexes and primary keys + # pylint: disable=protected-access + table = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE, + trans=False)._table + self._db.add_constraint(table.primary_key) + for index in table.indexes: + self._db.add_index(index) + return 2 + else: + raise NotImplementedError() + class SAML2SessionStore(Store): def __init__(self, database_url): super(SAML2SessionStore, self).__init__(database_url=database_url) - self.table = 'sessions' + self.table = 'saml2_sessions' # pylint: disable=protected-access - table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table + table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table table.create(checkfirst=True) def _get_unique_id_from_column(self, name, value): @@ -539,7 +725,7 @@ class SAML2SessionStore(Store): def remove_expired_sessions(self): # pylint: disable=protected-access - table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table + table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table sel = select([table.columns.uuid]). \ where(and_(table.c.name == 'expiration_time', table.c.value <= datetime.datetime.now())) @@ -551,6 +737,10 @@ class SAML2SessionStore(Store): return self.get_unique_data(self.table, idval, name, value) def new_session(self, datum): + if 'supported_logout_mechs' in datum: + datum['supported_logout_mechs'] = ','.join( + datum['supported_logout_mechs'] + ) return self.new_unique_data(self.table, datum) def get_session(self, session_id=None, request_id=None): @@ -567,7 +757,7 @@ class SAML2SessionStore(Store): def get_user_sessions(self, user): """ - Retrun a list of all sessions for a given user. + Return a list of all sessions for a given user. """ rows = self.get_unique_data(self.table, name='user', value=user) @@ -575,6 +765,8 @@ class SAML2SessionStore(Store): logged_in = [] for r in rows: data = self.get_unique_data(self.table, uuidval=r) + data[r]['supported_logout_mechs'] = data[r].get( + 'supported_logout_mechs', '').split(',') logged_in.append(data) return logged_in @@ -587,3 +779,22 @@ class SAML2SessionStore(Store): def wipe_data(self): self._reset_data(self.table) + + def _initialize_schema(self): + q = self._query(self._db, self.table, UNIQUE_DATA_TABLE, + trans=False) + q.create() + q._con.close() # pylint: disable=protected-access + + def _upgrade_schema(self, old_version): + if old_version == 1: + # In schema version 2, we added indexes and primary keys + # pylint: disable=protected-access + table = self._query(self._db, self.table, UNIQUE_DATA_TABLE, + trans=False)._table + self._db.add_constraint(table.primary_key) + for index in table.indexes: + self._db.add_index(index) + return 2 + else: + raise NotImplementedError()