Add support for logout over SOAP
[cascardo/ipsilon.git] / ipsilon / util / data.py
index b7fde31..e0cd6e1 100644 (file)
@@ -1,11 +1,12 @@
 # Copyright (C) 2013 Ipsilon project Contributors, for license see COPYING
 
 import cherrypy
+import datetime
 from ipsilon.util.log import Log
 from sqlalchemy import create_engine
 from sqlalchemy import MetaData, Table, Column, Text
 from sqlalchemy.pool import QueuePool, SingletonThreadPool
-from sqlalchemy.sql import select
+from sqlalchemy.sql import select, and_
 import ConfigParser
 import os
 import uuid
@@ -509,3 +510,86 @@ class TranStore(Store):
 
     def __init__(self, path=None):
         super(TranStore, self).__init__('transactions.db')
+
+
+class SAML2SessionStore(Store):
+
+    def __init__(self, database_url):
+        super(SAML2SessionStore, self).__init__(database_url=database_url)
+        self.table = 'sessions'
+        # pylint: disable=protected-access
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        table.create(checkfirst=True)
+
+    def _get_unique_id_from_column(self, name, value):
+        """
+        The query is going to return only the column in the query.
+        Use this method to get the uuidval which can be used to fetch
+        the entire entry.
+
+        Returns None or the uuid of the first value found.
+        """
+        data = self.get_unique_data(self.table, name=name, value=value)
+        count = len(data)
+        if count == 0:
+            return None
+        elif count != 1:
+            raise ValueError("Multiple entries returned")
+        return data.keys()[0]
+
+    def remove_expired_sessions(self):
+        # pylint: disable=protected-access
+        table = SqlQuery(self._db, self.table, UNIQUE_DATA_COLUMNS)._table
+        sel = select([table.columns.uuid]). \
+            where(and_(table.c.name == 'expiration_time',
+                       table.c.value <= datetime.datetime.now()))
+        # pylint: disable=no-value-for-parameter
+        d = table.delete().where(table.c.uuid.in_(sel))
+        d.execute()
+
+    def get_data(self, idval=None, name=None, value=None):
+        return self.get_unique_data(self.table, idval, name, value)
+
+    def new_session(self, datum):
+        if 'supported_logout_mechs' in datum:
+            datum['supported_logout_mechs'] = ','.join(
+                datum['supported_logout_mechs']
+            )
+        return self.new_unique_data(self.table, datum)
+
+    def get_session(self, session_id=None, request_id=None):
+        if session_id:
+            uuidval = self._get_unique_id_from_column('session_id', session_id)
+        elif request_id:
+            uuidval = self._get_unique_id_from_column('request_id', request_id)
+        else:
+            raise ValueError("Unable to find session")
+        if not uuidval:
+            return None, None
+        data = self.get_unique_data(self.table, uuidval=uuidval)
+        return uuidval, data[uuidval]
+
+    def get_user_sessions(self, user):
+        """
+        Return a list of all sessions for a given user.
+        """
+        rows = self.get_unique_data(self.table, name='user', value=user)
+
+        # We have a list of sessions for this user, now get the details
+        logged_in = []
+        for r in rows:
+            data = self.get_unique_data(self.table, uuidval=r)
+            data[r]['supported_logout_mechs'] = data[r].get(
+                'supported_logout_mechs', '').split(',')
+            logged_in.append(data)
+
+        return logged_in
+
+    def update_session(self, datum):
+        self.save_unique_data(self.table, datum)
+
+    def remove_session(self, uuidval):
+        self.del_unique_data(self.table, uuidval)
+
+    def wipe_data(self):
+        self._reset_data(self.table)