X-Git-Url: http://git.cascardo.eti.br/?a=blobdiff_plain;f=ipsilon%2Futil%2Fdata.py;h=e0cd6e1726b61cab41054e2fe599a9daeaa2cd4a;hb=2751451f4158417e66974d6415d2da84f612ab3c;hp=94d402bd9aacd4248db36cf00c5582c8e0149e97;hpb=4d234dc3956e8ac72d8e0eccc3c0d9594d1c85f8;p=cascardo%2Fipsilon.git diff --git a/ipsilon/util/data.py b/ipsilon/util/data.py index 94d402b..e0cd6e1 100644 --- a/ipsilon/util/data.py +++ b/ipsilon/util/data.py @@ -1,48 +1,71 @@ -# Copyright (C) 2013 Simo Sorce -# -# see file 'COPYING' for use and warranty information -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . +# Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING import cherrypy +import datetime from ipsilon.util.log import Log from sqlalchemy import create_engine from sqlalchemy import MetaData, Table, Column, Text -from sqlalchemy.sql import select +from sqlalchemy.pool import QueuePool, SingletonThreadPool +from sqlalchemy.sql import select, and_ import ConfigParser import os import uuid +import logging +CURRENT_SCHEMA_VERSION = 1 OPTIONS_COLUMNS = ['name', 'option', 'value'] UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value'] class SqlStore(Log): + __instances = {} + + @classmethod + def get_connection(cls, name): + if name not in cls.__instances.keys(): + if cherrypy.config.get('db.conn.log', False): + logging.debug('SqlStore new: %s', name) + cls.__instances[name] = SqlStore(name) + return cls.__instances[name] def __init__(self, name): + self.db_conn_log = cherrypy.config.get('db.conn.log', False) + self.debug('SqlStore init: %s' % name) + self.name = name engine_name = name if '://' not in engine_name: engine_name = 'sqlite:///' + engine_name - self._dbengine = create_engine(engine_name) + # This pool size is per configured database. The minimum needed, + # determined by binary search, is 23. We're using 25 so we have a bit + # more playroom, and then the overflow should make sure things don't + # break when we suddenly need more. + pool_args = {'poolclass': QueuePool, + 'pool_size': 25, + 'max_overflow': 50} + if engine_name.startswith('sqlite://'): + # It's not possible to share connections for SQLite between + # threads, so let's use the SingletonThreadPool for them + pool_args = {'poolclass': SingletonThreadPool} + self._dbengine = create_engine(engine_name, **pool_args) self.is_readonly = False + def debug(self, fact): + if self.db_conn_log: + super(SqlStore, self).debug(fact) + def engine(self): return self._dbengine def connection(self): - return self._dbengine.connect() + self.debug('SqlStore connect: %s' % self.name) + conn = self._dbengine.connect() + + def cleanup_connection(): + self.debug('SqlStore cleanup: %s' % self.name) + conn.close() + cherrypy.request.hooks.attach('on_end_request', cleanup_connection) + return conn def SqlAutotable(f): @@ -244,8 +267,35 @@ class Store(Log): self._db = FileStore(filename) self._query = FileQuery else: - self._db = SqlStore(name) + self._db = SqlStore.get_connection(name) self._query = SqlQuery + self._upgrade_database() + + def _upgrade_database(self): + if self.is_readonly: + # If the database is readonly, we cannot do anything to the + # schema. Let's just return, and assume people checked the + # upgrade notes + return + current_version = self.load_options('dbinfo').get('scheme', None) + if current_version is None or 'version' not in current_version: + # No version stored, storing current version + self.save_options('dbinfo', 'scheme', + {'version': CURRENT_SCHEMA_VERSION}) + current_version = CURRENT_SCHEMA_VERSION + else: + current_version = int(current_version['version']) + if current_version != CURRENT_SCHEMA_VERSION: + self.debug('Upgrading database schema from %i to %i' % ( + current_version, CURRENT_SCHEMA_VERSION)) + self._upgrade_database_from(current_version) + + def _upgrade_database_from(self, old_schema_version): + # Insert code here to upgrade from old_schema_version to + # CURRENT_SCHEMA_VERSION + raise Exception('Unable to upgrade database to current schema' + ' version: version %i is unknown!' % + old_schema_version) @property def is_readonly(self): @@ -377,10 +427,14 @@ class Store(Log): datum = data[uid] for name in datum: if name in curvals: - q.update({'value': datum[name]}, - {'uuid': uid, 'name': name}) + if datum[name] is None: + q.delete({'uuid': uid, 'name': name}) + else: + q.update({'value': datum[name]}, + {'uuid': uid, 'name': name}) else: - q.insert((uid, name, datum[name])) + if datum[name] is not None: + q.insert((uid, name, datum[name])) q.commit() except Exception, e: # pylint: disable=broad-except @@ -398,6 +452,7 @@ class Store(Log): self.error("Failed to delete data from %s: [%s]" % (table, e)) def _reset_data(self, table): + q = None try: q = self._query(self._db, table, UNIQUE_DATA_COLUMNS) q.drop() @@ -455,3 +510,86 @@ class TranStore(Store): def __init__(self, path=None): super(TranStore, self).__init__('transactions.db') + + +class SAML2SessionStore(Store): + + def __init__(self, database_url): + super(SAML2SessionStore, self).__init__(database_url=database_url) + self.table = 'sessions' + # pylint: disable=protected-access + table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table + table.create(checkfirst=True) + + def _get_unique_id_from_column(self, name, value): + """ + The query is going to return only the column in the query. + Use this method to get the uuidval which can be used to fetch + the entire entry. + + Returns None or the uuid of the first value found. + """ + data = self.get_unique_data(self.table, name=name, value=value) + count = len(data) + if count == 0: + return None + elif count != 1: + raise ValueError("Multiple entries returned") + return data.keys()[0] + + def remove_expired_sessions(self): + # pylint: disable=protected-access + table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table + sel = select([table.columns.uuid]). \ + where(and_(table.c.name == 'expiration_time', + table.c.value <= datetime.datetime.now())) + # pylint: disable=no-value-for-parameter + d = table.delete().where(table.c.uuid.in_(sel)) + d.execute() + + def get_data(self, idval=None, name=None, value=None): + return self.get_unique_data(self.table, idval, name, value) + + def new_session(self, datum): + if 'supported_logout_mechs' in datum: + datum['supported_logout_mechs'] = ','.join( + datum['supported_logout_mechs'] + ) + return self.new_unique_data(self.table, datum) + + def get_session(self, session_id=None, request_id=None): + if session_id: + uuidval = self._get_unique_id_from_column('session_id', session_id) + elif request_id: + uuidval = self._get_unique_id_from_column('request_id', request_id) + else: + raise ValueError("Unable to find session") + if not uuidval: + return None, None + data = self.get_unique_data(self.table, uuidval=uuidval) + return uuidval, data[uuidval] + + def get_user_sessions(self, user): + """ + Return a list of all sessions for a given user. + """ + rows = self.get_unique_data(self.table, name='user', value=user) + + # We have a list of sessions for this user, now get the details + logged_in = [] + for r in rows: + data = self.get_unique_data(self.table, uuidval=r) + data[r]['supported_logout_mechs'] = data[r].get( + 'supported_logout_mechs', '').split(',') + logged_in.append(data) + + return logged_in + + def update_session(self, datum): + self.save_unique_data(self.table, datum) + + def remove_session(self, uuidval): + self.del_unique_data(self.table, uuidval) + + def wipe_data(self): + self._reset_data(self.table)