1 # Copyright (C) 2015 Rob Crittenden <rcritten@redhat.com>
3 # see file 'COPYING' for use and warranty information
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 from ipsilon.util.log import Log
class SAMLSession(Log):
    """
    One SAML login session, used to track its login/logout state.

    Attributes:
       session_id  - ID of the login session
       provider_id - ID of the SP
       session     - the Login session object
       logoutstate - None by default; set_logoutstate() replaces it with
                     a dict holding:
          relaystate - the RelayState from the Logout Request or Response
          id         - ID of the Logout Request that initiated the logout
          request    - dump of the initial Logout Request (may be None)
    """

    def __init__(self, session_id, provider_id, session,
                 logoutstate=None):
        # Which login session this is, and which SP it belongs to.
        self.session_id, self.provider_id = session_id, provider_id
        # The Login session object plus any in-progress logout state.
        self.session, self.logoutstate = session, logoutstate

    def set_logoutstate(self, relaystate, request_id, request=None):
        """Replace the logout state with a new dict built from the args."""
        self.logoutstate = {
            'relaystate': relaystate,
            'id': request_id,
            'request': request,
        }

    def dump(self):
        """Write each attribute of this session to the debug log."""
        self.debug('session_id %s' % self.session_id)
        self.debug('provider_id %s' % self.provider_id)
        self.debug('session %s' % self.session)
        self.debug('logoutstate %s' % self.logoutstate)
class SAMLSessionsContainer(Log):
    """
    Store SAML session information.

    Sessions are kept in two dicts, both keyed on provider_id, which
    represent the state a session is in:

       sessions             - sessions that are logged in
       sessions_logging_out - sessions that have started logging out

    When a user logs in, add_session() creates a new SAMLSession and adds
    it to the sessions dict, replacing any earlier session for the same
    provider.

    When a user logs out, get_next_logout() (or start_logout()) moves a
    session from sessions to sessions_logging_out.
    remove_session_by_provider() and find_session_by_provider() look in
    both dicts.
    """

    def __init__(self):
        self.sessions = dict()
        self.sessions_logging_out = dict()

    def add_session(self, session_id, provider_id, session):
        """
        Add a new session to the logged-in bucket.

        Drop any existing sessions that might exist for this
        provider. We have no control over the SPs, so if one sends
        us another login, accept it.
        """
        samlsession = SAMLSession(session_id, provider_id, session)

        self.remove_session_by_provider(provider_id)
        self.sessions[provider_id] = samlsession
        self.dump()

    def remove_session_by_provider(self, provider_id):
        """
        Remove all instances of this provider from either session
        pool. A provider with no session is silently ignored.
        """
        self.sessions.pop(provider_id, None)
        self.sessions_logging_out.pop(provider_id, None)

    def find_session_by_provider(self, provider_id):
        """
        Return the session for provider_id from either pool, checking
        the logged-in pool first.

        Return None if no session for the provider is found.
        """
        if provider_id in self.sessions:
            return self.sessions[provider_id]
        return self.sessions_logging_out.get(provider_id)

    def start_logout(self, session):
        """
        Move a session into the logging_out state.

        Does nothing if the session's provider is already logging out.
        Raises KeyError if the provider has no session in either pool.
        No return value.
        """
        if session.provider_id in self.sessions_logging_out:
            return

        # Move the object actually stored in the pool, which need not be
        # the same instance the caller passed in.
        session = self.sessions.pop(session.provider_id)
        self.sessions_logging_out[session.provider_id] = session

    def get_next_logout(self):
        """
        Get the next session in the logged-in state and move it to the
        logging_out state. Return the session that was moved.

        Return None if no more sessions are in the logged-in state.
        """
        if not self.sessions:
            return None

        # next(iter(d)) works on Python 2 and 3; d.keys()[0] raises
        # TypeError on Python 3 because dict views are not indexable.
        provider_id = next(iter(self.sessions))
        session = self.sessions.pop(provider_id)

        # Plain assignment replaces any stale logging-out entry for this
        # provider, so no separate removal is needed.
        self.sessions_logging_out[provider_id] = session

        return session

    def get_last_session(self):
        """
        Remove and return the single remaining logging-out session.

        Raises ValueError unless exactly one session is left across both
        pools. Returns None if that one session is still in the logged-in
        pool rather than logging out.
        """
        if self.count() != 1:
            raise ValueError('Not exactly one session left')

        if not self.sessions_logging_out:
            return None

        provider_id = next(iter(self.sessions_logging_out))
        return self.sessions_logging_out.pop(provider_id)

    def count(self):
        """
        Return number of active login/logging out sessions.
        """
        return len(self.sessions) + len(self.sessions_logging_out)

    def dump(self):
        """
        Write every session in both pools to the debug log, numbered
        consecutively across the two pools.
        """
        index = 0
        pools = (('Login Session', self.sessions),
                 ('Logging-out Session', self.sessions_logging_out))
        for label, pool in pools:
            for session in pool.values():
                self.debug('%s: %d' % (label, index))
                session.dump()
                self.debug('-----------------------')
                index += 1
if __name__ == '__main__':
    # Minimal self-test of SAMLSessionsContainer bookkeeping.
    provider1 = "http://127.0.0.10/saml2"
    provider2 = "http://127.0.0.11/saml2"

    saml_sessions = SAMLSessionsContainer()

    # An empty container has no "last" session. Fail explicitly if no
    # ValueError is raised, instead of silently skipping the check.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        pass
    else:
        raise AssertionError('get_last_session() did not raise with '
                             'no sessions')
    assert saml_sessions.count() == 0

    saml_sessions.add_session("_123456",
                              provider1,
                              "sessiondata")

    saml_sessions.add_session("_789012",
                              provider2,
                              "sessiondata")

    # With two sessions there is still not exactly one left.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        pass
    else:
        raise AssertionError('get_last_session() did not raise with '
                             'two sessions')
    assert saml_sessions.count() == 2

    testsession = saml_sessions.find_session_by_provider(provider1)
    assert testsession.provider_id == provider1
    assert testsession.session_id == "_123456"
    assert testsession.session == "sessiondata"

    # Test get_next_logout() by fetching both values out. Dict order is
    # not relied on; just check that each provider comes out exactly once.
    providers = [provider1, provider2]
    testsession = saml_sessions.get_next_logout()
    providers.remove(testsession.provider_id)  # should be one of them

    testsession = saml_sessions.get_next_logout()
    assert testsession.provider_id == providers[0]  # should be the other

    # Already logging out, so this must be a no-op.
    saml_sessions.start_logout(testsession)
    saml_sessions.remove_session_by_provider(provider2)

    assert saml_sessions.count() == 1

    testsession = saml_sessions.get_last_session()
    assert testsession.provider_id == provider1

    saml_sessions.remove_session_by_provider(provider1)
    assert saml_sessions.count() == 0