1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.schema import (PrimaryKeyConstraint, Index, AddConstraint,
11 from sqlalchemy.sql import select, and_
18 CURRENT_SCHEMA_VERSION = 2
19 OPTIONS_TABLE = {'columns': ['name', 'option', 'value'],
20 'primary_key': ('name', 'option'),
21 'indexes': [('name',)]
23 UNIQUE_DATA_TABLE = {'columns': ['uuid', 'name', 'value'],
24 'primary_key': ('uuid', 'name'),
25 'indexes': [('uuid',)]
29 class DatabaseError(Exception):
34 # Some helper functions used for upgrades
def add_constraint(self, table):
    """Add a constraint to a table that already exists.

    The base store cannot alter schemas; stores that can must override
    this.  Despite its name, ``table`` receives the constraint object
    (callers in this module pass a table's ``primary_key``).

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def add_index(self, index):
    """Create *index* on a table that already exists.

    The base store cannot alter schemas; stores that can must override
    this.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
42 class SqlStore(BaseStore):
def get_connection(cls, name):
    """Return the shared SqlStore for database *name*, creating it once.

    Stores are cached by name in the class-level ``__instances`` mapping,
    so every caller asking for the same name gets the same store (and thus
    the same SQLAlchemy engine).  Creation is logged when ``db.conn.log``
    is enabled in the cherrypy configuration.
    """
    cache = cls.__instances
    if name in cache:
        return cache[name]
    if cherrypy.config.get('db.conn.log', False):
        logging.debug('SqlStore new: %s', name)
    cache[name] = SqlStore(name)
    return cache[name]
53 def __init__(self, name):
54 self.db_conn_log = cherrypy.config.get('db.conn.log', False)
55 self.debug('SqlStore init: %s' % name)
58 if '://' not in engine_name:
59 engine_name = 'sqlite:///' + engine_name
60 # This pool size is per configured database. The minimum needed,
61 # determined by binary search, is 23. We're using 25 so we have a bit
62 # more playroom, and then the overflow should make sure things don't
63 # break when we suddenly need more.
64 pool_args = {'poolclass': QueuePool,
67 if engine_name.startswith('sqlite://'):
68 # It's not possible to share connections for SQLite between
69 # threads, so let's use the SingletonThreadPool for them
70 pool_args = {'poolclass': SingletonThreadPool}
71 self._dbengine = create_engine(engine_name, **pool_args)
72 self.is_readonly = False
def add_constraint(self, constraint):
    """Add *constraint* to a table that already exists in the database.

    On SQLite this is a no-op: SQLite cannot add constraints to a
    pre-existing table (source: http://www.sqlite.org/omitted.html).
    """
    if self._dbengine.dialect.name == 'sqlite':
        return
    AddConstraint(constraint, bind=self._dbengine).execute()
82 def add_index(self, index):
83 add_index = CreateIndex(index, bind=self._dbengine)
86 def debug(self, fact):
88 super(SqlStore, self).debug(fact)
94 self.debug('SqlStore connect: %s' % self.name)
95 conn = self._dbengine.connect()
97 def cleanup_connection():
98 self.debug('SqlStore cleanup: %s' % self.name)
100 cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
106 def __init__(self, db_obj, table, table_def, trans=True):
108 self._con = self._db.connection()
109 self._trans = self._con.begin() if trans else None
110 self._table = self._get_table(table, table_def)
112 def _get_table(self, name, table_def):
113 if isinstance(table_def, list):
114 table_def = {'columns': table_def,
118 for col_name in table_def['columns']:
119 table_creation.append(Column(col_name, Text()))
120 if table_def['primary_key']:
121 table_creation.append(PrimaryKeyConstraint(
122 *table_def['primary_key']))
123 for index in table_def['indexes']:
124 idx_name = 'idx_%s_%s' % (name, '_'.join(index))
125 table_creation.append(Index(idx_name, *index))
126 table = Table(name, MetaData(self._db.engine()), *table_creation)
129 def _where(self, kvfilter):
131 if kvfilter is not None:
133 w = self._table.columns[k] == kvfilter[k]
140 def _columns(self, columns=None):
142 if columns is not None:
145 cols.append(self._table.columns[c])
147 cols = self._table.columns
151 self._trans.rollback()
157 self._table.create(checkfirst=True)
160 self._table.drop(checkfirst=True)
def select(self, kvfilter=None, columns=None):
    """Run a SELECT against this query's table and return the result.

    :param kvfilter: optional ``{column: value}`` mapping; each entry
        becomes an equality test in the WHERE clause built by ``_where``.
    :param columns: optional list of column names to return; all of the
        table's columns when None.
    """
    wanted_columns = self._columns(columns)
    where_clause = self._where(kvfilter)
    statement = select(wanted_columns, where_clause)
    return self._con.execute(statement)
def insert(self, values):
    """Insert *values* (a row tuple, as callers pass) into this table."""
    statement = self._table.insert(values)
    self._con.execute(statement)
def update(self, values, kvfilter):
    """Set the columns in *values* on every row matching *kvfilter*.

    :param values: ``{column: new_value}`` mapping to write.
    :param kvfilter: ``{column: value}`` equality filter selecting rows.
    """
    condition = self._where(kvfilter)
    statement = self._table.update(condition, values)
    self._con.execute(statement)
def delete(self, kvfilter):
    """Delete every row of this table that matches *kvfilter*.

    :param kvfilter: ``{column: value}`` equality filter selecting rows.
    """
    condition = self._where(kvfilter)
    statement = self._table.delete(condition)
    self._con.execute(statement)
176 class FileStore(BaseStore):
178 def __init__(self, name):
179 self._filename = name
180 self.is_readonly = True
181 self._timestamp = None
184 def get_config(self):
186 stat = os.stat(self._filename)
188 self.error("Unable to check config file %s: [%s]" % (
192 timestamp = stat.st_mtime
193 if self._config is None or timestamp > self._timestamp:
194 self._config = ConfigParser.RawConfigParser()
195 self._config.optionxform = str
196 self._config.read(self._filename)
def add_constraint(self, table):
    """Unsupported: a config-file store is read-only (``is_readonly``).

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def add_index(self, index):
    """Unsupported: a config-file store is read-only (``is_readonly``).

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
206 class FileQuery(Log):
208 def __init__(self, fstore, table, table_def, trans=True):
209 # We don't need indexes in a FileQuery, so drop that info
210 if isinstance(table_def, dict):
211 columns = table_def['columns']
214 self._fstore = fstore
215 self._config = fstore.get_config()
216 self._section = table
217 if len(columns) > 3 or columns[-1] != 'value':
218 raise ValueError('Unsupported configuration format')
219 self._columns = columns
228 raise NotImplementedError
231 raise NotImplementedError
233 def select(self, kvfilter=None, columns=None):
234 if self._section not in self._config.sections():
237 opts = self._config.options(self._section)
241 if self._columns[0] in kvfilter:
242 prefix = kvfilter[self._columns[0]]
243 prefix_ = prefix + ' '
246 if len(self._columns) == 3 and self._columns[1] in kvfilter:
247 name = kvfilter[self._columns[1]]
250 if self._columns[-1] in kvfilter:
251 value = kvfilter[self._columns[-1]]
255 if len(self._columns) == 3:
257 if prefix and not o.startswith(prefix_):
260 col1, col2 = o.split(' ', 1)
261 if name and col2 != name:
264 col3 = self._config.get(self._section, o)
265 if value and col3 != value:
268 r = [col1, col2, col3]
271 if prefix and o != prefix:
273 r = [o, self._config.get(self._section, o)]
278 s.append(r[self._columns.index(c)])
283 self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
def insert(self, values):
    """Unsupported: queries over a config file cannot write rows.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def update(self, values, kvfilter):
    """Unsupported: queries over a config file cannot modify rows.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def delete(self, kvfilter):
    """Unsupported: queries over a config file cannot delete rows.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
302 def __init__(self, config_name=None, database_url=None):
303 if config_name is None and database_url is None:
304 raise ValueError('config_name or database_url must be provided')
306 if config_name not in cherrypy.config:
307 raise NameError('Unknown database %s' % config_name)
308 name = cherrypy.config[config_name]
311 if name.startswith('configfile://'):
312 _, filename = name.split('://')
313 self._db = FileStore(filename)
314 self._query = FileQuery
316 self._db = SqlStore.get_connection(name)
317 self._query = SqlQuery
319 if not self._is_upgrade:
320 self._check_database()
def _code_schema_version(self):
    """Return the schema version this code expects for this store.

    This is a method rather than a constant so that separate stores (for
    example, plugin stores) can override it and carry their own schema
    version.  The default is the module-wide ``CURRENT_SCHEMA_VERSION``.
    """
    return CURRENT_SCHEMA_VERSION
328 def _get_schema_version(self):
329 # We are storing multiple versions: one per class
330 # That way, we can support plugins with differing schema versions from
331 # the main codebase, and even in the same database.
332 q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)
334 cls_name = self.__class__.__name__
335 current_version = self.load_options('dbinfo').get('%s_schema'
337 if 'version' in current_version:
338 return int(current_version['version'])
340 # Also try the old table name.
341 # "scheme" was a typo, but we need to retain that now for compat
342 fallback_version = self.load_options('dbinfo').get('scheme',
344 if 'version' in fallback_version:
345 return int(fallback_version['version'])
349 def _check_database(self):
351 # If the database is readonly, we cannot do anything to the
352 # schema. Let's just return, and assume people checked the
356 current_version = self._get_schema_version()
357 if current_version is None:
358 self.error('Database initialization required! ' +
359 'Please run ipsilon-upgrade-database')
360 raise DatabaseError('Database initialization required for %s' %
361 self.__class__.__name__)
362 if current_version != self._code_schema_version():
363 self.error('Database upgrade required! ' +
364 'Please run ipsilon-upgrade-database')
365 raise DatabaseError('Database upgrade required for %s' %
366 self.__class__.__name__)
368 def _store_new_schema_version(self, new_version):
369 cls_name = self.__class__.__name__
370 self.save_options('dbinfo', '%s_schema' % cls_name,
371 {'version': new_version})
373 def _initialize_schema(self):
374 raise NotImplementedError()
376 def _upgrade_schema(self, old_version):
377 # Datastores need to figure out what to do with bigger old_versions
379 # They might implement downgrading if that's feasible, or just throw
380 # NotImplementedError
381 # Should return the new schema version
382 raise NotImplementedError()
384 def upgrade_database(self):
385 # Do whatever is needed to get schema to current version
386 old_schema_version = self._get_schema_version()
387 if old_schema_version is None:
388 # Just initialize a new schema
389 self._initialize_schema()
390 self._store_new_schema_version(self._code_schema_version())
391 elif old_schema_version != self._code_schema_version():
392 # Upgrade from old_schema_version to code_schema_version
393 self.debug('Upgrading from schema version %i' % old_schema_version)
394 new_version = self._upgrade_schema(old_schema_version)
396 error = ('Schema upgrade error: %s did not provide a ' +
397 'new schema version number!' %
398 self.__class__.__name__)
400 raise Exception(error)
401 self._store_new_schema_version(new_version)
402 # Check if we are now up-to-date
403 self.upgrade_database()
def is_readonly(self):
    """Report whether the underlying store (``self._db``) is read-only."""
    backend = self._db
    return backend.is_readonly
409 def _row_to_dict_tree(self, data, row):
415 self._row_to_dict_tree(d2, row[1:])
419 if data[name] is list:
420 data[name].append(value)
423 data[name] = [v, value]
427 def _rows_to_dict_tree(self, rows):
430 self._row_to_dict_tree(data, r)
433 def _load_data(self, table, columns, kvfilter=None):
436 q = self._query(self._db, table, columns, trans=False)
437 rows = q.select(kvfilter)
438 except Exception, e: # pylint: disable=broad-except
439 self.error("Failed to load data for table %s: [%s]" % (table, e))
440 return self._rows_to_dict_tree(rows)
442 def load_config(self):
444 columns = ['name', 'value']
445 return self._load_data(table, columns)
447 def load_options(self, table, name=None):
450 kvfilter['name'] = name
451 options = self._load_data(table, OPTIONS_TABLE, kvfilter)
452 if name and name in options:
456 def save_options(self, table, name, options):
460 q = self._query(self._db, table, OPTIONS_TABLE)
461 rows = q.select({'name': name}, ['option', 'value'])
463 curvals[row[0]] = row[1]
467 q.update({'value': options[opt]},
468 {'name': name, 'option': opt})
470 q.insert((name, opt, options[opt]))
473 except Exception, e: # pylint: disable=broad-except
476 self.error("Failed to save options: [%s]" % e)
479 def delete_options(self, table, name, options=None):
480 kvfilter = {'name': name}
483 q = self._query(self._db, table, OPTIONS_TABLE)
488 kvfilter['option'] = opt
491 except Exception, e: # pylint: disable=broad-except
494 self.error("Failed to delete from %s: [%s]" % (table, e))
497 def new_unique_data(self, table, data):
498 newid = str(uuid.uuid4())
501 q = self._query(self._db, table, UNIQUE_DATA_TABLE)
503 q.insert((newid, name, data[name]))
505 except Exception, e: # pylint: disable=broad-except
508 self.error("Failed to store %s data: [%s]" % (table, e))
512 def get_unique_data(self, table, uuidval=None, name=None, value=None):
515 kvfilter['uuid'] = uuidval
517 kvfilter['name'] = name
519 kvfilter['value'] = value
520 return self._load_data(table, UNIQUE_DATA_TABLE, kvfilter)
522 def save_unique_data(self, table, data):
525 q = self._query(self._db, table, UNIQUE_DATA_TABLE)
528 rows = q.select({'uuid': uid}, ['name', 'value'])
535 if datum[name] is None:
536 q.delete({'uuid': uid, 'name': name})
538 q.update({'value': datum[name]},
539 {'uuid': uid, 'name': name})
541 if datum[name] is not None:
542 q.insert((uid, name, datum[name]))
545 except Exception, e: # pylint: disable=broad-except
548 self.error("Failed to store data in %s: [%s]" % (table, e))
551 def del_unique_data(self, table, uuidval):
552 kvfilter = {'uuid': uuidval}
554 q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)
556 except Exception, e: # pylint: disable=broad-except
557 self.error("Failed to delete data from %s: [%s]" % (table, e))
559 def _reset_data(self, table):
562 q = self._query(self._db, table, UNIQUE_DATA_TABLE)
566 except Exception, e: # pylint: disable=broad-except
569 self.error("Failed to erase all data from %s: [%s]" % (table, e))
572 class AdminStore(Store):
575 super(AdminStore, self).__init__('admin.config.db')
def get_data(self, plugin, idval=None, name=None, value=None):
    """Load rows from *plugin*'s ``<plugin>_data`` table.

    The optional uuid/name/value filters are passed straight through to
    ``get_unique_data``.
    """
    data_table = plugin + "_data"
    return self.get_unique_data(data_table, idval, name, value)
def save_data(self, plugin, data):
    """Save *data* into *plugin*'s ``<plugin>_data`` table."""
    data_table = plugin + "_data"
    return self.save_unique_data(data_table, data)
def new_datum(self, plugin, datum):
    """Store *datum* as a new entry in *plugin*'s ``<plugin>_data`` table.

    Returns whatever ``new_unique_data`` returns for the new entry.
    """
    return self.new_unique_data(plugin + "_data", datum)
def del_datum(self, plugin, idval):
    """Delete the entry *idval* from *plugin*'s ``<plugin>_data`` table."""
    return self.del_unique_data(plugin + "_data", idval)
def wipe_data(self, plugin):
    """Erase all entries in *plugin*'s ``<plugin>_data`` table."""
    self._reset_data(plugin + "_data")
595 def _initialize_schema(self):
596 for table in ['config',
600 q = self._query(self._db, table, OPTIONS_TABLE, trans=False)
603 def _upgrade_schema(self, old_version):
605 # In schema version 2, we added indexes and primary keys
606 for table in ['config',
610 # pylint: disable=protected-access
611 table = self._query(self._db, table, OPTIONS_TABLE,
613 self._db.add_constraint(table.primary_key)
614 for index in table.indexes:
615 self._db.add_index(index)
618 raise NotImplementedError()
620 def create_plugin_data_table(self, plugin_name):
621 table = plugin_name+'_data'
622 q = self._query(self._db, table, UNIQUE_DATA_TABLE,
627 class UserStore(Store):
def __init__(self, path=None):
    """Open the user-preferences store configured as ``user.prefs.db``.

    :param path: unused; accepted for interface compatibility.
    """
    super(UserStore, self).__init__(config_name='user.prefs.db')
def save_user_preferences(self, user, options):
    """Persist *options* as *user*'s preferences in the ``users`` table."""
    prefs_table = 'users'
    self.save_options(prefs_table, user, options)
def load_user_preferences(self, user):
    """Return *user*'s preferences from the ``users`` table."""
    prefs_table = 'users'
    return self.load_options(prefs_table, user)
def save_plugin_data(self, plugin, user, options):
    """Persist *user*'s *options* in *plugin*'s ``<plugin>_data`` table."""
    data_table = plugin + "_data"
    self.save_options(data_table, user, options)
def load_plugin_data(self, plugin, user):
    """Return *user*'s options from *plugin*'s ``<plugin>_data`` table."""
    data_table = plugin + "_data"
    return self.load_options(data_table, user)
644 def _initialize_schema(self):
645 q = self._query(self._db, 'users', OPTIONS_TABLE, trans=False)
648 def _upgrade_schema(self, old_version):
650 # In schema version 2, we added indexes and primary keys
651 # pylint: disable=protected-access
652 table = self._query(self._db, 'users', OPTIONS_TABLE,
654 self._db.add_constraint(table.primary_key)
655 for index in table.indexes:
656 self._db.add_index(index)
659 raise NotImplementedError()
662 class TranStore(Store):
def __init__(self, path=None):
    """Open the transaction store configured as ``transactions.db``.

    :param path: unused; accepted for interface compatibility.
    """
    super(TranStore, self).__init__(config_name='transactions.db')
667 def _initialize_schema(self):
668 q = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
672 def _upgrade_schema(self, old_version):
674 # In schema version 2, we added indexes and primary keys
675 # pylint: disable=protected-access
676 table = self._query(self._db, 'transactions', UNIQUE_DATA_TABLE,
678 self._db.add_constraint(table.primary_key)
679 for index in table.indexes:
680 self._db.add_index(index)
683 raise NotImplementedError()
686 class SAML2SessionStore(Store):
def __init__(self, database_url):
    """Open the SAML2 session store and make sure its table exists.

    :param database_url: database URL handed to ``Store.__init__``.
    """
    super(SAML2SessionStore, self).__init__(database_url=database_url)
    self.table = 'saml2_sessions'
    # Only the Table object is needed here, to issue CREATE TABLE.  Pass
    # trans=False (as the other schema-creation paths do) so no
    # transaction is begun on the connection.  With the default
    # trans=True, this code never commits or rolls back that transaction.
    # pylint: disable=protected-access
    table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE,
                     trans=False)._table
    table.create(checkfirst=True)
695 def _get_unique_id_from_column(self, name, value):
697 The query is going to return only the column in the query.
698 Use this method to get the uuidval which can be used to fetch
701 Returns None or the uuid of the first value found.
703 data = self.get_unique_data(self.table, name=name, value=value)
708 raise ValueError("Multiple entries returned")
709 return data.keys()[0]
711 def remove_expired_sessions(self):
712 # pylint: disable=protected-access
713 table = SqlQuery(self._db, self.table, UNIQUE_DATA_TABLE)._table
714 sel = select([table.columns.uuid]). \
715 where(and_(table.c.name == 'expiration_time',
716 table.c.value <= datetime.datetime.now()))
717 # pylint: disable=no-value-for-parameter
718 d = table.delete().where(table.c.uuid.in_(sel))
def get_data(self, idval=None, name=None, value=None):
    """Return SAML2 session rows, optionally filtered by uuid/name/value."""
    return self.get_unique_data(self.table, uuidval=idval, name=name,
                                value=value)
724 def new_session(self, datum):
725 if 'supported_logout_mechs' in datum:
726 datum['supported_logout_mechs'] = ','.join(
727 datum['supported_logout_mechs']
729 return self.new_unique_data(self.table, datum)
731 def get_session(self, session_id=None, request_id=None):
733 uuidval = self._get_unique_id_from_column('session_id', session_id)
735 uuidval = self._get_unique_id_from_column('request_id', request_id)
737 raise ValueError("Unable to find session")
740 data = self.get_unique_data(self.table, uuidval=uuidval)
741 return uuidval, data[uuidval]
743 def get_user_sessions(self, user):
745 Return a list of all sessions for a given user.
747 rows = self.get_unique_data(self.table, name='user', value=user)
749 # We have a list of sessions for this user, now get the details
752 data = self.get_unique_data(self.table, uuidval=r)
753 data[r]['supported_logout_mechs'] = data[r].get(
754 'supported_logout_mechs', '').split(',')
755 logged_in.append(data)
def update_session(self, datum):
    """Write the changed session fields in *datum* to the sessions table."""
    sessions_table = self.table
    self.save_unique_data(sessions_table, datum)
def remove_session(self, uuidval):
    """Delete the session identified by *uuidval* from the sessions table."""
    sessions_table = self.table
    self.del_unique_data(sessions_table, uuidval)
766 self._reset_data(self.table)
768 def _initialize_schema(self):
769 q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
773 def _upgrade_schema(self, old_version):
775 # In schema version 2, we added indexes and primary keys
776 # pylint: disable=protected-access
777 table = self._query(self._db, self.table, UNIQUE_DATA_TABLE,
779 self._db.add_constraint(table.primary_key)
780 for index in table.indexes:
781 self._db.add_index(index)
784 raise NotImplementedError()