1 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
5 from ipsilon.util.log import Log
6 from sqlalchemy import create_engine
7 from sqlalchemy import MetaData, Table, Column, Text
8 from sqlalchemy.pool import QueuePool, SingletonThreadPool
9 from sqlalchemy.sql import select, and_
# Schema version this code writes and expects; Store._upgrade_database
# records it under the 'dbinfo' table, option 'scheme', key 'version'.
CURRENT_SCHEMA_VERSION = 1

# Column layout of "options" tables: one row per (name, option) pair.
OPTIONS_COLUMNS = [
    'name',
    'option',
    'value',
]

# Column layout of "unique data" tables: rows grouped under a uuid.
UNIQUE_DATA_COLUMNS = [
    'uuid',
    'name',
    'value',
]
def get_connection(cls, name):
    """Return the shared SqlStore for database `name`.

    Stores are cached in the class-private ``__instances`` mapping, so
    every caller asking for the same name gets the same SqlStore (and
    therefore the same engine). Creation is logged at debug level when
    cherrypy's 'db.conn.log' setting is enabled.
    """
    cache = cls.__instances
    if name in cache:
        return cache[name]
    if cherrypy.config.get('db.conn.log', False):
        logging.debug('SqlStore new: %s', name)
    store = SqlStore(name)
    cache[name] = store
    return store
# SqlStore.__init__: set up the SQLAlchemy engine for database `name`.
# - cherrypy.config 'db.conn.log' controls connection debug logging.
# - An engine name without '://' is treated as an SQLite file path.
# - SQLite engines get a SingletonThreadPool (connections cannot be
#   shared between threads); other engines get a QueuePool.
# NOTE(review): the lines that set self.name / engine_name and the
# remaining QueuePool arguments are not visible in this view.
32 def __init__(self, name):
33 self.db_conn_log = cherrypy.config.get('db.conn.log', False)
34 self.debug('SqlStore init: %s' % name)
37 if '://' not in engine_name:
38 engine_name = 'sqlite:///' + engine_name
39 # This pool size is per configured database. The minimum needed,
40 # determined by binary search, is 23. We're using 25 so we have a bit
41 # more playroom, and then the overflow should make sure things don't
42 # break when we suddenly need more.
43 pool_args = {'poolclass': QueuePool,
46 if engine_name.startswith('sqlite://'):
47 # It's not possible to share connections for SQLite between
48 # threads, so let's use the SingletonThreadPool for them
49 pool_args = {'poolclass': SingletonThreadPool}
50 self._dbengine = create_engine(engine_name, **pool_args)
# SQL-backed stores accept writes (FileStore, by contrast, is read-only).
51 self.is_readonly = False
# debug(): pass `fact` on to Log.debug. It is presumably gated on
# self.db_conn_log by a line not visible here — confirm.
53 def debug(self, fact):
55 super(SqlStore, self).debug(fact)
# NOTE(review): the def line of the method these lines belong to is not
# visible. The visible code opens a new connection on the engine and
# attaches cleanup_connection to cherrypy's 'on_end_request' hook, so it
# runs when the current request finishes. cleanup_connection presumably
# also closes `conn` on a line not visible here — confirm.
61 self.debug('SqlStore connect: %s' % self.name)
62 conn = self._dbengine.connect()
64 def cleanup_connection():
65 self.debug('SqlStore cleanup: %s' % self.name)
67 cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
# Inner wrapper of a decorator whose enclosing def is not visible here.
# It forwards all arguments to the wrapped callable `f` and returns its
# result. Whatever runs before that call is on a line not visible here.
72 def at(self, *args, **kwargs):
74 return f(self, *args, **kwargs)
# SqlQuery.__init__: prepare queries against `table`, whose column names
# are listed in `columns` (see _get_table). When `trans` is true, a
# transaction is begun on the connection and kept in self._trans;
# otherwise self._trans is None.
# NOTE(review): self._db is presumably assigned from db_obj on a line not
# visible here — confirm.
80 def __init__(self, db_obj, table, columns, trans=True):
82 self._con = self._db.connection()
83 self._trans = self._con.begin() if trans else None
84 self._table = self._get_table(table, columns)
# Build a SQLAlchemy Table `name` on a fresh MetaData bound to this
# store's engine, adding one Text column per name in `columns`.
# The loop header and return statement are on lines not visible here.
86 def _get_table(self, name, columns):
87 table = Table(name, MetaData(self._db.engine()))
89 table.append_column(Column(c, Text()))
# Turn a {column_name: value} dict into an SQL equality condition on
# self._table. Only the per-key comparison is visible here. How several
# keys are combined (and_ is imported) and what a None filter yields are
# on lines not visible here — confirm.
92 def _where(self, kvfilter):
94 if kvfilter is not None:
96 w = self._table.columns[k] == kvfilter[k]
# Resolve the column names in `columns` to Column objects of self._table.
# When `columns` is None, presumably (hidden else branch) all of the
# table's columns are used.
103 def _columns(self, columns=None):
105 if columns is not None:
108 cols.append(self._table.columns[c])
110 cols = self._table.columns
# NOTE(review): the next three lines are bodies of methods whose def
# lines are not visible here: a transaction rollback, table creation and
# table drop. With checkfirst=True, SQLAlchemy skips the CREATE when the
# table already exists and skips the DROP when it is absent.
114 self._trans.rollback()
120 self._table.create(checkfirst=True)
123 self._table.drop(checkfirst=True)
def select(self, kvfilter=None, columns=None):
    """Run a SELECT against this query's table.

    :param kvfilter: optional {column: value} equality filter.
    :param columns: optional list of column names to return; every
        column of the table when None.
    :returns: whatever the connection's ``execute`` returns for the
        statement.
    """
    wanted = self._columns(columns)
    condition = self._where(kvfilter)
    statement = select(wanted, condition)
    return self._con.execute(statement)
def insert(self, values):
    """Insert one row into the table.

    :param values: the row values; callers pass a tuple ordered like the
        table's columns.
    """
    statement = self._table.insert(values)
    self._con.execute(statement)
def update(self, values, kvfilter):
    """Set the columns in `values` on every row matching `kvfilter`.

    :param values: {column: new_value} mapping.
    :param kvfilter: {column: value} equality filter selecting the rows.
    """
    condition = self._where(kvfilter)
    statement = self._table.update(condition, values)
    self._con.execute(statement)
def delete(self, kvfilter):
    """Delete every row matching the {column: value} filter `kvfilter`."""
    condition = self._where(kvfilter)
    self._con.execute(self._table.delete(condition))
# Read-only configuration store backed by a file parsed with
# ConfigParser.RawConfigParser.
143 class FileStore(Log):
145 def __init__(self, name):
146 self._filename = name
147 self.is_readonly = True
148 self._timestamp = None
# get_config(): return the parsed configuration. The file is re-read when
# it has never been parsed or its mtime is newer than the last parse.
# optionxform = str keeps option names case-sensitive (RawConfigParser
# lowercases them by default).
# NOTE(review): several lines are not visible here: the try/except around
# os.stat, the early return after the error, the update of
# self._timestamp and the final return. self._config is presumably
# initialised to None in __init__ on a hidden line — confirm.
151 def get_config(self):
153 stat = os.stat(self._filename)
155 self.error("Unable to check config file %s: [%s]" % (
159 timestamp = stat.st_mtime
160 if self._config is None or timestamp > self._timestamp:
161 self._config = ConfigParser.RawConfigParser()
162 self._config.optionxform = str
163 self._config.read(self._filename)
# Query adapter exposing one ConfigParser section through the same
# select()/insert()/update()/delete() surface as SqlQuery; only select()
# is implemented.
167 class FileQuery(Log):
# __init__: `table` names the config section. `columns` must have at most
# 3 entries and end with 'value', otherwise ValueError is raised (an
# empty list would raise IndexError instead). `trans` is not used by the
# visible lines.
169 def __init__(self, fstore, table, columns, trans=True):
170 self._fstore = fstore
171 self._config = fstore.get_config()
172 self._section = table
173 if len(columns) > 3 or columns[-1] != 'value':
174 raise ValueError('Unsupported configuration format')
175 self._columns = columns
# NOTE(review): the two raises below belong to methods whose def lines
# are not visible here.
184 raise NotImplementedError
187 raise NotImplementedError
# select(): emulate an SQL SELECT over the options of section
# self._section.
# - With 3 columns, each option key is "<col1> <col2>" (split on the
#   first space) and the option value is col3.
# - With 2 columns, the option key is the first column and the option
#   value the second.
# - kvfilter entries restrict the rows returned. In the 3-column case the
#   first-column filter is matched as a key prefix followed by a space.
# - Each row is projected onto the requested column names through
#   self._columns.index, and the result is logged with debug().
# NOTE(review): the filter checks use truthiness (`if prefix`, `if name`,
# `if value`), so filtering on an empty string is silently ignored.
# Loop headers, defaults, the None-kvfilter handling and the returns are
# on lines not visible here.
189 def select(self, kvfilter=None, columns=None):
190 if self._section not in self._config.sections():
193 opts = self._config.options(self._section)
197 if self._columns[0] in kvfilter:
198 prefix = kvfilter[self._columns[0]]
199 prefix_ = prefix + ' '
202 if len(self._columns) == 3 and self._columns[1] in kvfilter:
203 name = kvfilter[self._columns[1]]
206 if self._columns[-1] in kvfilter:
207 value = kvfilter[self._columns[-1]]
211 if len(self._columns) == 3:
213 if prefix and not o.startswith(prefix_):
216 col1, col2 = o.split(' ', 1)
217 if name and col2 != name:
220 col3 = self._config.get(self._section, o)
221 if value and col3 != value:
224 r = [col1, col2, col3]
227 if prefix and o != prefix:
229 r = [o, self._config.get(self._section, o)]
234 s.append(r[self._columns.index(c)])
239 self.debug('SELECT(%s, %s, %s) -> %s' % (self._section,
def insert(self, values):
    """Reject row insertion on a configuration-file query.

    Configuration files are opened through FileStore, which marks itself
    ``is_readonly = True``, so writes are never supported.

    :raises NotImplementedError: always. The message names the operation
        so callers that log ``str(e)`` report something meaningful
        instead of an empty string.
    """
    raise NotImplementedError(
        'FileQuery does not support insert: configuration files are '
        'read-only')
def update(self, values, kvfilter):
    """Reject row updates on a configuration-file query.

    Configuration files are opened through FileStore, which marks itself
    ``is_readonly = True``, so writes are never supported.

    :raises NotImplementedError: always, with a descriptive message.
    """
    raise NotImplementedError(
        'FileQuery does not support update: configuration files are '
        'read-only')
def delete(self, kvfilter):
    """Reject row deletion on a configuration-file query.

    Configuration files are opened through FileStore, which marks itself
    ``is_readonly = True``, so writes are never supported.

    :raises NotImplementedError: always, with a descriptive message.
    """
    raise NotImplementedError(
        'FileQuery does not support delete: configuration files are '
        'read-only')
# Store.__init__: open the backing store named either by a cherrypy
# config key (`config_name`) or directly by `database_url`.
# - A 'configfile://<path>' name selects a read-only FileStore/FileQuery
#   pair.
# - Otherwise SqlStore.get_connection/SqlQuery are used.
# - _upgrade_database() is then invoked (its nesting relative to the
#   branches is not visible here).
# Raises ValueError when neither argument is given, and NameError when
# config_name is not a cherrypy config key.
# NOTE(review): `_, filename = name.split('://')` needs exactly one '://',
# so a path containing another '://' raises ValueError;
# split('://', 1) would be safer. Some branching lines are not visible.
256 def __init__(self, config_name=None, database_url=None):
257 if config_name is None and database_url is None:
258 raise ValueError('config_name or database_url must be provided')
260 if config_name not in cherrypy.config:
261 raise NameError('Unknown database %s' % config_name)
262 name = cherrypy.config[config_name]
265 if name.startswith('configfile://'):
266 _, filename = name.split('://')
267 self._db = FileStore(filename)
268 self._query = FileQuery
270 self._db = SqlStore.get_connection(name)
271 self._query = SqlQuery
272 self._upgrade_database()
# Make sure the stored schema version matches CURRENT_SCHEMA_VERSION.
# - Read-only stores are left untouched (see the comment below).
# - If no version is stored in 'dbinfo'/'scheme', the current version is
#   written.
# - Otherwise, a stored version that differs triggers
#   _upgrade_database_from().
# NOTE(review): the readonly check/return and the `else:` pairing with
# the no-version branch are on lines not visible here.
274 def _upgrade_database(self):
276 # If the database is readonly, we cannot do anything to the
277 # schema. Let's just return, and assume people checked the
280 current_version = self.load_options('dbinfo').get('scheme', None)
281 if current_version is None or 'version' not in current_version:
282 # No version stored, storing current version
283 self.save_options('dbinfo', 'scheme',
284 {'version': CURRENT_SCHEMA_VERSION})
285 current_version = CURRENT_SCHEMA_VERSION
287 current_version = int(current_version['version'])
288 if current_version != CURRENT_SCHEMA_VERSION:
289 self.debug('Upgrading database schema from %i to %i' % (
290 current_version, CURRENT_SCHEMA_VERSION))
291 self._upgrade_database_from(current_version)
# Placeholder for schema migrations. No upgrade paths exist yet, so the
# visible body always raises an Exception naming the unknown old version.
# NOTE(review): the final operand of the % expression is on a line not
# visible here.
293 def _upgrade_database_from(self, old_schema_version):
294 # Insert code here to upgrade from old_schema_version to
295 # CURRENT_SCHEMA_VERSION
296 raise Exception('Unable to upgrade database to current schema'
297 ' version: version %i is unknown!' %
def is_readonly(self):
    """Report whether the backing store refuses writes.

    Delegates to the ``is_readonly`` flag of ``self._db``: False for
    SqlStore, True for FileStore.
    """
    backend = self._db
    return backend.is_readonly
# Fold one row (a sequence of column values) into the nested dict `data`.
# Longer rows recurse on row[1:] into a sub-dict d2. At the last level a
# key maps to its value, and a key seen again is meant to collect a list
# of values. Key/value extraction and the branch conditions are on lines
# not visible here.
304 def _row_to_dict_tree(self, data, row):
310 self._row_to_dict_tree(d2, row[1:])
# NOTE(review): BUG — `data[name] is list` tests identity with the `list`
# type object, which a stored value never is, so this branch is dead.
# Assuming the hidden line sets v to the previous value, a third value
# for the same key ends up as [[v1, v2], v3] instead of [v1, v2, v3].
# Use `isinstance(data[name], list)`.
314 if data[name] is list:
315 data[name].append(value)
318 data[name] = [v, value]
# Build a nested dict from an iterable of rows by folding each row in
# with _row_to_dict_tree. The initialisation of `data`, the loop header
# and the return are on lines not visible here.
322 def _rows_to_dict_tree(self, rows):
325 self._row_to_dict_tree(data, r)
# Select rows from `table` (optionally filtered by kvfilter) with a
# non-transactional query, and return them as a nested dict tree.
# Query errors are logged, not raised.
# NOTE(review): if the query raises, `rows` must already be bound for the
# final return to succeed. It is presumably initialised on a line not
# visible here — confirm.
328 def _load_data(self, table, columns, kvfilter=None):
331 q = self._query(self._db, table, columns, trans=False)
332 rows = q.select(kvfilter)
333 except Exception, e: # pylint: disable=broad-except
334 self.error("Failed to load data for table %s: [%s]" % (table, e))
335 return self._rows_to_dict_tree(rows)
# Load a two-column (name, value) table as a {name: value} tree.
# NOTE(review): `table` is assigned on a line not visible here.
337 def load_config(self):
339 columns = ['name', 'value']
340 return self._load_data(table, columns)
# Load OPTIONS_COLUMNS rows from `table` as a nested tree, filtered on
# 'name' when `name` is given. When `name` is given and present, the
# hidden return presumably yields only that entry's options — confirm.
342 def load_options(self, table, name=None):
345 kvfilter['name'] = name
346 options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
347 if name and name in options:
# Upsert the option dict `options` for `name` in `table`. Current values
# are read first; options that already exist are presumably updated and
# the rest inserted. Failures are logged.
# NOTE(review): the loop, the membership test, the commit and any
# rollback are on lines not visible here.
351 def save_options(self, table, name, options):
355 q = self._query(self._db, table, OPTIONS_COLUMNS)
356 rows = q.select({'name': name}, ['option', 'value'])
358 curvals[row[0]] = row[1]
362 q.update({'value': options[opt]},
363 {'name': name, 'option': opt})
365 q.insert((name, opt, options[opt]))
368 except Exception, e: # pylint: disable=broad-except
371 self.error("Failed to save options: [%s]" % e)
# Delete the rows for `name` from `table`. When `options` is given, the
# filter is presumably narrowed per option via kvfilter['option'] (the
# loop and delete calls are not visible here). Failures are logged.
374 def delete_options(self, table, name, options=None):
375 kvfilter = {'name': name}
378 q = self._query(self._db, table, OPTIONS_COLUMNS)
383 kvfilter['option'] = opt
386 except Exception, e: # pylint: disable=broad-except
389 self.error("Failed to delete from %s: [%s]" % (table, e))
# Insert every key of `data` as a (uuid, name, value) row under one
# freshly generated uuid4 string. Failures are logged. The loop, commit
# and return (presumably of newid) are on lines not visible here.
392 def new_unique_data(self, table, data):
393 newid = str(uuid.uuid4())
396 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
398 q.insert((newid, name, data[name]))
400 except Exception, e: # pylint: disable=broad-except
403 self.error("Failed to store %s data: [%s]" % (table, e))
# Return unique-data rows from `table` as a {uuid: {name: value}} tree,
# filtered on uuid, name and/or value. The guards that add each filter
# only when given are on lines not visible here.
407 def get_unique_data(self, table, uuidval=None, name=None, value=None):
410 kvfilter['uuid'] = uuidval
412 kvfilter['name'] = name
414 kvfilter['value'] = value
415 return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
# Persist `data`, a {uuid: {name: value}} mapping. For each uuid, the
# current (name, value) rows are read:
# - a name whose new value is None is deleted;
# - other existing names are presumably updated;
# - new non-None names are inserted.
# Failures are logged.
# NOTE(review): loop headers, existence checks, commit and rollback are
# on lines not visible here.
417 def save_unique_data(self, table, data):
420 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
423 rows = q.select({'uuid': uid}, ['name', 'value'])
430 if datum[name] is None:
431 q.delete({'uuid': uid, 'name': name})
433 q.update({'value': datum[name]},
434 {'uuid': uid, 'name': name})
436 if datum[name] is not None:
437 q.insert((uid, name, datum[name]))
440 except Exception, e: # pylint: disable=broad-except
443 self.error("Failed to store data in %s: [%s]" % (table, e))
# Delete every row of `table` belonging to `uuidval`, using a
# non-transactional query. Failures are logged. The delete call itself
# is on a line not visible here.
446 def del_unique_data(self, table, uuidval):
447 kvfilter = {'uuid': uuidval}
449 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
451 except Exception, e: # pylint: disable=broad-except
452 self.error("Failed to delete data from %s: [%s]" % (table, e))
# Erase all data in `table`. The statements that actually clear it are
# on lines not visible here. Failures are logged.
454 def _reset_data(self, table):
457 q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
461 except Exception, e: # pylint: disable=broad-except
464 self.error("Failed to erase all data from %s: [%s]" % (table, e))
# Store for plugin administration data, opened from the database
# configured under 'admin.config.db'. Each plugin's data lives in a
# "<plugin>_data" table.
# NOTE(review): the __init__ def line is not visible here.
467 class AdminStore(Store):
470 super(AdminStore, self).__init__('admin.config.db')
def get_data(self, plugin, idval=None, name=None, value=None):
    """Fetch entries from the "<plugin>_data" table.

    The optional idval/name/value arguments filter on the uuid, name and
    value columns respectively.
    """
    table = plugin + "_data"
    return self.get_unique_data(table, uuidval=idval, name=name,
                                value=value)
def save_data(self, plugin, data):
    """Persist a {uuid: {name: value}} mapping into "<plugin>_data"."""
    target = plugin + "_data"
    return self.save_unique_data(target, data)
def new_datum(self, plugin, datum):
    """Store `datum` under a new uuid in "<plugin>_data".

    Returns whatever new_unique_data returns.
    """
    return self.new_unique_data(plugin + "_data", datum)
def del_datum(self, plugin, idval):
    """Remove every entry stored under uuid `idval` in "<plugin>_data"."""
    return self.del_unique_data(plugin + "_data", idval)
def wipe_data(self, plugin):
    """Erase everything stored in the "<plugin>_data" table."""
    self._reset_data(plugin + "_data")
class UserStore(Store):
    """Per-user preferences plus per-plugin user data.

    Opened from the database configured under 'user.prefs.db'.
    Preferences live in the 'users' options table and plugin data in
    "<plugin>_data" options tables, both keyed by the user name.
    """

    def __init__(self, path=None):
        # `path` is accepted for call compatibility; it is not used.
        super(UserStore, self).__init__(config_name='user.prefs.db')

    def save_user_preferences(self, user, options):
        """Save the option dict `options` as the preferences of `user`."""
        self.save_options(table='users', name=user, options=options)

    def load_user_preferences(self, user):
        """Return the preferences stored for `user`."""
        return self.load_options(table='users', name=user)

    def save_plugin_data(self, plugin, user, options):
        """Save `options` for `user` in the "<plugin>_data" table."""
        self.save_options(table=plugin + "_data", name=user,
                          options=options)

    def load_plugin_data(self, plugin, user):
        """Return the data stored for `user` in "<plugin>_data"."""
        return self.load_options(table=plugin + "_data", name=user)
class TranStore(Store):
    """Store opened from the database configured under 'transactions.db'."""

    def __init__(self, path=None):
        # `path` is accepted for call compatibility; it is not used.
        super(TranStore, self).__init__(config_name='transactions.db')
# Store for SAML2 sessions, kept as unique data ({uuid: {name: value}})
# in a 'sessions' table of the database at the given URL.
515 class SAML2SessionStore(Store):
def __init__(self, database_url):
    """Open the SAML2 session store at `database_url`.

    Ensures the 'sessions' table (UNIQUE_DATA_COLUMNS layout) exists.
    """
    super(SAML2SessionStore, self).__init__(database_url=database_url)
    self.table = 'sessions'
    # The SqlQuery is only needed for its Table object. With the default
    # trans=True it would begin a transaction that this throwaway query
    # never commits or rolls back, so don't start one (as _load_data
    # does for read-only work).
    # pylint: disable=protected-access
    table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS,
                     trans=False)._table
    table.create(checkfirst=True)
# Look up the uuid of the session whose `name` column equals `value`.
# Multiple matches raise ValueError. The no-match handling and the
# multiple-match test are on lines not visible here.
# NOTE(review): `data.keys()[0]` relies on Python 2 returning a list from
# dict.keys(); under Python 3 use next(iter(data)).
524 def _get_unique_id_from_column(self, name, value):
526 The query is going to return only the column in the query.
527 Use this method to get the uuidval which can be used to fetch
530 Returns None or the uuid of the first value found.
532 data = self.get_unique_data(self.table, name=name, value=value)
537 raise ValueError("Multiple entries returned")
538 return data.keys()[0]
# Delete every row of each session whose 'expiration_time' entry is <= now.
# NOTE(review): `value` is a Text column, so comparing it with
# datetime.datetime.now() relies on the database comparing the stored
# string with the bound datetime. That is only correct if stored
# timestamps use a format that orders like the bound value — confirm.
# NOTE(review): the SqlQuery here uses the default trans=True and is
# discarded, leaving its transaction unfinished. The statement is
# executed on a line not visible here.
540 def remove_expired_sessions(self):
541 # pylint: disable=protected-access
542 table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
543 sel = select([table.columns.uuid]). \
544 where(and_(table.c.name == 'expiration_time',
545 table.c.value <= datetime.datetime.now()))
546 # pylint: disable=no-value-for-parameter
547 d = table.delete().where(table.c.uuid.in_(sel))
def get_data(self, idval=None, name=None, value=None):
    """Query the sessions table, filtering on uuid/name/value when given."""
    return self.get_unique_data(self.table, uuidval=idval, name=name,
                                value=value)
def new_session(self, datum):
    """Store `datum` as a new session; returns new_unique_data's result."""
    sessions = self.table
    return self.new_unique_data(sessions, datum)
# Return (uuid, data) for a session, looked up by session_id or by
# request_id. Which one is used is decided on lines not visible here.
# Raises ValueError("Unable to find session") when no uuid is found.
556 def get_session(self, session_id=None, request_id=None):
558 uuidval = self._get_unique_id_from_column('session_id', session_id)
560 uuidval = self._get_unique_id_from_column('request_id', request_id)
562 raise ValueError("Unable to find session")
565 data = self.get_unique_data(self.table, uuidval=uuidval)
566 return uuidval, data[uuidval]
# Collect the full data of each session whose 'user' entry equals `user`.
# The loop header, list initialisation and return are on lines not
# visible here.
568 def get_user_sessions(self, user):
570 Return a list of all sessions for a given user.
572 rows = self.get_unique_data(self.table, name='user', value=user)
574 # We have a list of sessions for this user, now get the details
577 data = self.get_unique_data(self.table, uuidval=r)
578 logged_in.append(data)
def update_session(self, datum):
    """Write `datum` ({uuid: {name: value}}) to the sessions table."""
    sessions = self.table
    self.save_unique_data(sessions, datum)
def remove_session(self, uuidval):
    """Delete every entry of session `uuidval` from the sessions table."""
    sessions = self.table
    self.del_unique_data(sessions, uuidval)
# NOTE(review): body of a method whose def line is not visible here; it
# erases the whole sessions table via _reset_data.
589 self._reset_data(self.table)