c9cfd9c4f7944c4a101348fcaf6fbcb46cdd7781
[cascardo/ipsilon.git] / ipsilon / providers / saml2 / sessions.py
1 # Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING
2
3 from ipsilon.util.log import Log
4
5
6 class SAMLSession(Log):
7     """
8     A SAML login session used to track login/logout state.
9
10        session_id - ID of the login session
11        provider_id - ID of the SP
12        session - the Login session object
13        logoutstate - dict containing logout state info
14        session_indexes - the IDs of any login session we've seen
15                          for this user
16
17     When a new session is seen for the same user any existing session
18     is thrown away. We keep the original session_id though and send
19     all that we've seen to the SP when performing a logout to ensure
20     that all sessions get logged out.
21
22     logout state is a dictionary containing (potentially)
23     these attributes:
24
25     relaystate - The relaystate from the Logout Request or Response
26     id         - The Logout request id that initiated the logout
27     request    - Dump of the initial logout request
28     """
29     def __init__(self, session_id, provider_id, session,
30                  logoutstate=None):
31
32         self.session_id = session_id
33         self.provider_id = provider_id
34         self.session = session
35         self.logoutstate = logoutstate
36         self.session_indexes = [session_id]
37
38     def set_logoutstate(self, relaystate, request_id, request=None):
39         self.logoutstate = dict(relaystate=relaystate,
40                                 id=request_id,
41                                 request=request)
42
43     def dump(self):
44         self.debug('session_id %s' % self.session_id)
45         self.debug('session_index %s' % self.session_indexes)
46         self.debug('provider_id %s' % self.provider_id)
47         self.debug('session %s' % self.session)
48         self.debug('logoutstate %s' % self.logoutstate)
49
50
51 class SAMLSessionsContainer(Log):
52     """
53     Store SAML session information.
54
55     The sessions are stored in two dicts which represent the state that
56     the session is in.
57
58     When a user logs in, add_session() is called and a new SAMLSession
59     created and added to the sessions dict, keyed on provider_id.
60
61     When a user logs out, the next login session is found and moved to
62     sessions_logging_out. remove_session() will look in both when trying
63     to remove a session.
64     """
65
66     def __init__(self):
67         self.sessions = dict()
68         self.sessions_logging_out = dict()
69
70     def add_session(self, session_id, provider_id, session):
71         """
72         Add a new session to the logged-in bucket.
73
74         Drop any existing sessions that might exist for this
75         provider. We have no control over the SP's so if it sends
76         us another login, accept it.
77
78         If an existing session exists drop it but keep a copy of
79         its session index. When we logout we send ALL session indexes
80         we've received to ensure that they are all logged out.
81         """
82         samlsession = SAMLSession(session_id, provider_id, session)
83
84         old_session = self.find_session_by_provider(provider_id)
85         if old_session is not None:
86             samlsession.session_indexes.extend(old_session.session_indexes)
87             self.debug("old session: %s" % old_session.session_indexes)
88             self.debug("new session: %s" % samlsession.session_indexes)
89             self.remove_session_by_provider(provider_id)
90         self.sessions[provider_id] = samlsession
91         self.dump()
92
93     def remove_session_by_provider(self, provider_id):
94         """
95         Remove all instances of this provider from either session
96         pool.
97         """
98         if provider_id in self.sessions:
99             self.sessions.pop(provider_id)
100         if provider_id in self.sessions_logging_out:
101             self.sessions_logging_out.pop(provider_id)
102
103     def find_session_by_provider(self, provider_id):
104         """
105         Return a given session from either pool.
106
107         Return None if no session for a provider is found.
108         """
109         if provider_id in self.sessions:
110             return self.sessions[provider_id]
111         if provider_id in self.sessions_logging_out:
112             return self.sessions_logging_out[provider_id]
113         return None
114
115     def start_logout(self, session):
116         """
117         Move a session into the logging_out state
118
119         No return value
120         """
121         if session.provider_id in self.sessions_logging_out:
122             return
123
124         session = self.sessions.pop(session.provider_id)
125
126         self.sessions_logging_out[session.provider_id] = session
127
128     def get_next_logout(self, remove=True):
129         """
130         Get the next session in the logged-in state and move
131         it to the logging_out state.  Return the session that is
132         found.
133
134         :param remove: for IdP-initiated logout we can't remove the
135                        session otherwise when the request comes back
136                        in the user won't be seen as being logged-on.
137
138         Return None if no more sessions in login state.
139         """
140         try:
141             provider_id = self.sessions.keys()[0]
142         except IndexError:
143             return None
144
145         if remove:
146             session = self.sessions.pop(provider_id)
147         else:
148             session = self.sessions.itervalues().next()
149
150         if provider_id in self.sessions_logging_out:
151             self.sessions_logging_out.pop(provider_id)
152
153         self.sessions_logging_out[provider_id] = session
154
155         return session
156
157     def get_last_session(self):
158         if self.count() != 1:
159             raise ValueError('Not exactly one session left')
160
161         try:
162             provider_id = self.sessions_logging_out.keys()[0]
163         except IndexError:
164             return None
165
166         return self.sessions_logging_out.pop(provider_id)
167
168     def count(self):
169         """
170         Return number of active login/logging out sessions.
171         """
172         return len(self.sessions) + len(self.sessions_logging_out)
173
174     def dump(self):
175         count = 0
176         for s in self.sessions:
177             self.debug('Login Session: %d' % count)
178             session = self.sessions[s]
179             session.dump()
180             self.debug('-----------------------')
181             count += 1
182         for s in self.sessions_logging_out:
183             self.debug('Logging-out Session: %d' % count)
184             session = self.sessions_logging_out[s]
185             session.dump()
186             self.debug('-----------------------')
187             count += 1
188
189 if __name__ == '__main__':
190     provider1 = "http://127.0.0.10/saml2"
191     provider2 = "http://127.0.0.11/saml2"
192
193     saml_sessions = SAMLSessionsContainer()
194
195     try:
196         testsession = saml_sessions.get_last_session()
197     except ValueError:
198         assert(saml_sessions.count() == 0)
199
200     saml_sessions.add_session("_123456",
201                               provider1,
202                               "sessiondata")
203
204     saml_sessions.add_session("_789012",
205                               provider2,
206                               "sessiondata")
207
208     try:
209         testsession = saml_sessions.get_last_session()
210     except ValueError:
211         assert(saml_sessions.count() == 2)
212
213     testsession = saml_sessions.find_session_by_provider(provider1)
214     assert(testsession.provider_id == provider1)
215     assert(testsession.session_id == "_123456")
216     assert(testsession.session == "sessiondata")
217
218     # Test get_next_logout() by fetching both values out. Do some
219     # basic accounting to ensure we get both values eventually.
220     providers = [provider1, provider2]
221     testsession = saml_sessions.get_next_logout()
222     providers.remove(testsession.provider_id)  # should be one of them
223
224     testsession = saml_sessions.get_next_logout()
225     assert(testsession.provider_id == providers[0])  # should be the other
226
227     saml_sessions.start_logout(testsession)
228     saml_sessions.remove_session_by_provider(provider2)
229
230     assert(saml_sessions.count() == 1)
231
232     testsession = saml_sessions.get_last_session()
233     assert(testsession.provider_id == provider1)
234
235     saml_sessions.remove_session_by_provider(provider1)
236     assert(saml_sessions.count() == 0)