-# Copyright (C) 2013 Simo Sorce <simo@redhat.com>
-#
-# see file 'COPYING' for use and warranty information
-#
-# This program is free software; you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
+# Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
import cherrypy
+import datetime
from ipsilon.util.log import Log
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, Column, Text
-from sqlalchemy.sql import select
+from sqlalchemy.pool import QueuePool, SingletonThreadPool
+from sqlalchemy.sql import select, and_
import ConfigParser
import os
import uuid
+import logging
+CURRENT_SCHEMA_VERSION = 1
OPTIONS_COLUMNS = ['name', 'option', 'value']
UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
class SqlStore(Log):
+ __instances = {}
+
+ @classmethod
+ def get_connection(cls, name):
+ if name not in cls.__instances.keys():
+ if cherrypy.config.get('db.conn.log', False):
+ logging.debug('SqlStore new: %s', name)
+ cls.__instances[name] = SqlStore(name)
+ return cls.__instances[name]
def __init__(self, name):
+ self.db_conn_log = cherrypy.config.get('db.conn.log', False)
+ self.debug('SqlStore init: %s' % name)
+ self.name = name
engine_name = name
if '://' not in engine_name:
engine_name = 'sqlite:///' + engine_name
- self._dbengine = create_engine(engine_name)
+ # This pool size is per configured database. The minimum needed,
+ # determined by binary search, is 23. We're using 25 so we have a bit
+ # more playroom, and then the overflow should make sure things don't
+ # break when we suddenly need more.
+ pool_args = {'poolclass': QueuePool,
+ 'pool_size': 25,
+ 'max_overflow': 50}
+ if engine_name.startswith('sqlite://'):
+ # It's not possible to share connections for SQLite between
+ # threads, so let's use the SingletonThreadPool for them
+ pool_args = {'poolclass': SingletonThreadPool}
+ self._dbengine = create_engine(engine_name, **pool_args)
self.is_readonly = False
+ def debug(self, fact):
+ if self.db_conn_log:
+ super(SqlStore, self).debug(fact)
+
def engine(self):
return self._dbengine
def connection(self):
- return self._dbengine.connect()
+ self.debug('SqlStore connect: %s' % self.name)
+ conn = self._dbengine.connect()
+
+ def cleanup_connection():
+ self.debug('SqlStore cleanup: %s' % self.name)
+ conn.close()
+ cherrypy.request.hooks.attach('on_end_request', cleanup_connection)
+ return conn
def SqlAutotable(f):
self._db = FileStore(filename)
self._query = FileQuery
else:
- self._db = SqlStore(name)
+ self._db = SqlStore.get_connection(name)
self._query = SqlQuery
+ self._upgrade_database()
+
+ def _upgrade_database(self):
+ if self.is_readonly:
+ # If the database is readonly, we cannot do anything to the
+ # schema. Let's just return, and assume people checked the
+ # upgrade notes
+ return
+ current_version = self.load_options('dbinfo').get('scheme', None)
+ if current_version is None or 'version' not in current_version:
+ # No version stored, storing current version
+ self.save_options('dbinfo', 'scheme',
+ {'version': CURRENT_SCHEMA_VERSION})
+ current_version = CURRENT_SCHEMA_VERSION
+ else:
+ current_version = int(current_version['version'])
+ if current_version != CURRENT_SCHEMA_VERSION:
+ self.debug('Upgrading database schema from %i to %i' % (
+ current_version, CURRENT_SCHEMA_VERSION))
+ self._upgrade_database_from(current_version)
+
+ def _upgrade_database_from(self, old_schema_version):
+ # Insert code here to upgrade from old_schema_version to
+ # CURRENT_SCHEMA_VERSION
+ raise Exception('Unable to upgrade database to current schema'
+ ' version: version %i is unknown!' %
+ old_schema_version)
@property
def is_readonly(self):
datum = data[uid]
for name in datum:
if name in curvals:
- q.update({'value': datum[name]},
- {'uuid': uid, 'name': name})
+ if datum[name] is None:
+ q.delete({'uuid': uid, 'name': name})
+ else:
+ q.update({'value': datum[name]},
+ {'uuid': uid, 'name': name})
else:
- q.insert((uid, name, datum[name]))
+ if datum[name] is not None:
+ q.insert((uid, name, datum[name]))
q.commit()
except Exception, e: # pylint: disable=broad-except
def __init__(self, path=None):
super(TranStore, self).__init__('transactions.db')
+
+
+class SAML2SessionStore(Store):
+
+ def __init__(self, database_url):
+ super(SAML2SessionStore, self).__init__(database_url=database_url)
+ self.table = 'sessions'
+ # pylint: disable=protected-access
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ table.create(checkfirst=True)
+
+ def _get_unique_id_from_column(self, name, value):
+ """
+ The query is going to return only the column in the query.
+ Use this method to get the uuidval which can be used to fetch
+ the entire entry.
+
+ Returns None or the uuid of the first value found.
+ """
+ data = self.get_unique_data(self.table, name=name, value=value)
+ count = len(data)
+ if count == 0:
+ return None
+ elif count != 1:
+ raise ValueError("Multiple entries returned")
+ return data.keys()[0]
+
+ def remove_expired_sessions(self):
+ # pylint: disable=protected-access
+ table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+ sel = select([table.columns.uuid]). \
+ where(and_(table.c.name == 'expiration_time',
+ table.c.value <= datetime.datetime.now()))
+ # pylint: disable=no-value-for-parameter
+ d = table.delete().where(table.c.uuid.in_(sel))
+ d.execute()
+
+ def get_data(self, idval=None, name=None, value=None):
+ return self.get_unique_data(self.table, idval, name, value)
+
+ def new_session(self, datum):
+ return self.new_unique_data(self.table, datum)
+
+ def get_session(self, session_id=None, request_id=None):
+ if session_id:
+ uuidval = self._get_unique_id_from_column('session_id', session_id)
+ elif request_id:
+ uuidval = self._get_unique_id_from_column('request_id', request_id)
+ else:
+ raise ValueError("Unable to find session")
+ if not uuidval:
+ return None, None
+ data = self.get_unique_data(self.table, uuidval=uuidval)
+ return uuidval, data[uuidval]
+
+ def get_user_sessions(self, user):
+ """
+ Retrun a list of all sessions for a given user.
+ """
+ rows = self.get_unique_data(self.table, name='user', value=user)
+
+ # We have a list of sessions for this user, now get the details
+ logged_in = []
+ for r in rows:
+ data = self.get_unique_data(self.table, uuidval=r)
+ logged_in.append(data)
+
+ return logged_in
+
+ def update_session(self, datum):
+ self.save_unique_data(self.table, datum)
+
+ def remove_session(self, uuidval):
+ self.del_unique_data(self.table, uuidval)
+
+ def wipe_data(self):
+ self._reset_data(self.table)