3 # Copyright (C) 2014 Simo Sorce <simo@redhat.com>
5 # see file 'COPYING' for use and warranty information
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 # GNU General Public License for more details.
17 # You should have received a copy of the GNU General Public License
18 # along with this program. If not, see <http://www.gnu.org/licenses/>.
26 from urllib import urlencode
# Raised by HttpSessions.handle_login_form / handle_return_form when a page
# does not contain the form they were asked to parse; fetch_page calls both
# handlers when it processes 200 responses.
# NOTE(review): this listing omits the class body (embedded source line 30);
# presumably it is just `pass` -- confirm against the real file.
29 class WrongPage(Exception):
# Wrapper around an HTTP response object (the code reads its .text and .url,
# e.g. a requests response) that exposes XPath queries over its HTML body.
# NOTE(review): this listing is missing several source lines of the class
# (embedded lines 34, 36, 38-41, 44-45, 49-54, 58-60, 63, 66); the comments
# below hedge wherever behaviour depends on them.
33 class PageTree(object):
# Keep the response body text; parsing is deferred until the tree is needed.
# NOTE(review): other methods read self.result (make_referer) and
# self._tree (the lazy parser below), so the missing lines 36/38 presumably
# set self.result = result and self._tree = None -- confirm.
35 def __init__(self, result):
37 self.text = result.text
# Lazy parse: build the HTML tree from self.text on first use and cache it in
# self._tree (`html` is presumably lxml.html; its import is outside this
# view). The enclosing accessor (lines 39-41, 44) is missing; because
# first_value/all_values read self.tree it is presumably a `tree` property
# returning self._tree.
42 if self._tree is None:
43 self._tree = html.fromstring(self.text)
# Evaluate an XPath rule and return a single value. xpath() may return a list
# (node-set results) or a scalar (e.g. string()/count() rules), hence the
# list check; the missing lines 49-54 presumably reduce a list to its first
# element (or None when empty) -- confirm.
46 def first_value(self, rule):
47 result = self.tree.xpath(rule)
48 if type(result) is list:
# Evaluate an XPath rule and return all matches. The missing lines 58-60
# presumably return list results as-is and wrap scalar results -- confirm.
55 def all_values(self, rule):
56 result = self.tree.xpath(rule)
57 if type(result) is list:
# URL of the wrapped response; the form handlers send it as the 'referer'
# header when submitting forms found on this page.
61 def make_referer(self):
62 return self.result.url
# Check that the first match of `rule` equals `expected`, raising ValueError
# with both values otherwise. The comparison itself (line 66) is missing from
# this listing.
64 def expected_value(self, rule, expected):
65 value = self.first_value(rule)
67 raise ValueError("Expected [%s], got [%s]" % (expected, value))
# Registry of named servers, each with its own requests.Session, plus helpers
# that drive login and SAML form flows against them (handle_login_form,
# handle_return_form, fetch_page, auth_to_idp) and configure SPs on an IdP
# (add_sp_metadata, set_sp_default_nameids, set_attributes_and_mapping).
# NOTE(review): embedded lines 71-74 are missing; presumably they define an
# __init__ that creates the self.servers dict used throughout -- confirm.
70 class HttpSessions(object):
# Register server `name` with its base URI and a fresh requests.Session, so
# cookies are kept separately per server.
# NOTE(review): lines 78-81 are missing; since handle_login_form and
# auth_to_idp read srv['user'] / srv['pwd'], they presumably store `user`
# and `pwd` in `new` -- confirm.
75 def add_server(self, name, baseuri, user=None, pwd=None):
76 new = {'baseuri': baseuri,
77 'session': requests.Session()}
82 self.servers[name] = new
# Return the session of a registered server whose baseuri is a prefix of
# `url` (get/post call .get/.post on the result); raise ValueError if no
# server matches. Lines 86 and 88-89 (presumably `d = self.servers[srv]` and
# the return of d['session']) are missing from this listing.
# NOTE(review): plain prefix matching means a baseuri like 'http://h:80' also
# matches 'http://h:8080/...'; which server wins then depends on dict
# iteration order.
84 def get_session(self, url):
85 for srv in self.servers:
87 if url.startswith(d['baseuri']):
90 raise ValueError("Unknown URL: %s" % url)
# GET `url` through the session of the server that owns it. Redirects are not
# followed automatically, so fetch_page can handle 303 responses itself.
# NOTE: callers must not pass allow_redirects in kwargs; it would be a
# duplicate keyword argument (TypeError).
92 def get(self, url, **kwargs):
93 session = self.get_session(url)
94 return session.get(url, allow_redirects=False, **kwargs)
# POST to `url` through the session of the server that owns it, without
# following redirects (fetch_page handles 303 responses itself).
# NOTE: callers must not pass allow_redirects in kwargs; it would be a
# duplicate keyword argument (TypeError).
96 def post(self, url, **kwargs):
97 session = self.get_session(url)
98 return session.post(url, allow_redirects=False, **kwargs)
# Dispatch a request by action name to get() / post(), lower-casing the name
# first; any other action raises ValueError. Lines 102 and 106 are missing
# (presumably `if action == 'get':` and `else:`).
# NOTE(review): string.lower() exists only in Python 2's string module;
# action.lower() is the equivalent that also works on Python 3.
100 def access(self, action, url, **kwargs):
101 action = string.lower(action)
103 return self.get(url, **kwargs)
104 elif action == 'post':
105 return self.post(url, **kwargs)
107 raise ValueError("Unknown action type: [%s]" % action)
# Turn a form action / link target into an absolute URL: a server-relative
# path ('/...') is prefixed with the scheme and host of `referer`.
# NOTE(review): line 113 is missing; presumably it returns `action` unchanged
# otherwise, in which case a relative path without a leading '/' is NOT
# resolved against `referer` (urlparse.urljoin would handle both cases).
109 def new_url(self, referer, action):
110 if action.startswith('/'):
111 u = urlparse.urlparse(referer)
112 return '%s://%s%s' % (u.scheme, u.netloc, action)
# Collect data from the form with id `form_id` on `page`: its action and
# method attributes, followed (presumably) by one list per attribute name in
# `input_fields`, holding that attribute of each direct <input> child of the
# form. Callers read results[0] as the action URL.
# NOTE(review): line 116 (presumably `values = []`) and lines 123-126 (rest of
# the XPath call, the append and the return) are missing from this listing.
# NOTE(review): form_id is interpolated into the XPath expression unescaped;
# fine for the literal ids used by callers, but an id containing '"' would
# break the query.
115 def get_form_data(self, page, form_id, input_fields):
117 action = page.first_value('//form[@id="%s"]/@action' % form_id)
118 values.append(action)
119 method = page.first_value('//form[@id="%s"]/@method' % form_id)
120 values.append(method)
121 for field in input_fields:
122 value = page.all_values('//form[@id="%s"]/input/@%s' % (form_id,
# Parse the 'login_form' on `page` and build the request that submits it with
# the credentials stored for server `idp`. Returns [method, absolute_url,
# kwargs], where kwargs carries the 'referer' header and the form payload;
# fetch_page unpacks this into (action, url, args) for self.access().
# Raises TypeError if `page` is not a PageTree, and WrongPage if the form
# cannot be extracted.
# NOTE(review): lines 130, 132-133, 136-138, 140, 143, 146, 149 and 153 are
# missing; presumably they open the try block, unpack method/names/values
# from `results`, raise when action_url is None, and create `payload = {}`.
# NOTE(review): the broad except turns any failure inside the try into
# WrongPage, hiding the original cause; and `type(page) != PageTree` rejects
# PageTree subclasses, where isinstance() would accept them.
127 def handle_login_form(self, idp, page):
128 if type(page) != PageTree:
129 raise TypeError("Expected PageTree object")
131 srv = self.servers[idp]
134 results = self.get_form_data(page, "login_form", ["name", "value"])
135 action_url = results[0]
139 if action_url is None:
141 except Exception: # pylint: disable=broad-except
142 raise WrongPage("Not a Login Form Page")
144 referer = page.make_referer()
145 headers = {'referer': referer}
# Echo back every input name/value pair found in the form.
# NOTE(review): names and values come from independent XPath queries
# (input/@name, input/@value), so an <input> without a value attribute shifts
# the pairing or makes values[i] raise IndexError.
147 for i in range(0, len(names)):
148 payload[names[i]] = values[i]
150 # replace known values
151 payload['login_name'] = srv['user']
152 payload['login_password'] = srv['pwd']
154 return [method, self.new_url(referer, action_url),
155 {'headers': headers, 'data': payload}]
# Parse the 'saml-response' form on `page` and build the request that
# re-submits it (presumably the auto-post form that carries a SAML response
# back to the requester). Returns [method, absolute_url, kwargs], where kwargs
# carries the 'referer' header and the form payload, for fetch_page to send.
# Raises TypeError if `page` is not a PageTree, and WrongPage if the form
# cannot be extracted.
# NOTE(review): lines 160-161, 163 (the input_fields argument), 166-169,
# 172, 175-176 and 179 are missing; presumably they open the try block,
# unpack method/names/values, raise when action_url is None, and create
# `payload = {}`.
# NOTE(review): as in handle_login_form, the broad except hides the real
# cause, and the `type() !=` check rejects PageTree subclasses.
157 def handle_return_form(self, page):
158 if type(page) != PageTree:
159 raise TypeError("Expected PageTree object")
162 results = self.get_form_data(page, "saml-response",
164 action_url = results[0]
165 if action_url is None:
170 except Exception: # pylint: disable=broad-except
171 raise WrongPage("Not a Return Form Page")
173 referer = page.make_referer()
174 headers = {'referer': referer}
# Echo back every input name/value pair found in the form.
# NOTE(review): names and values presumably come from independent XPath
# queries, so an <input> lacking one attribute would misalign the pairing.
177 for i in range(0, len(names)):
178 payload[names[i]] = values[i]
180 return [method, self.new_url(referer, action_url),
181 {'headers': headers, 'data': payload}]
# Fetch `target_url` for server `idp`, walking through redirects and the
# login / SAML-response forms until a final page is reached, and return it.
# Callers use the result as a PageTree: auth_to_idp calls .expected_value()
# and fetch_rest_page reads .text.
# Visible behaviour:
#  - a 303 response follows its Location header unless follow_redirect is
#    False;
#  - a 200 response is offered to handle_login_form(idp, ...) and then to
#    handle_return_form(...), whose [action, url, kwargs] become the next
#    request;
#  - presumably any other status ends in ValueError("Unhandled status ...").
# NOTE(review): most of the loop (lines 184-188, 192, 194-195, 197-199,
# 201-205, 207-210, 212-213, 215-216) is missing from this listing. That
# includes the initial action/url/args, the loop header, how WrongPage is
# handled and what is returned; confirm against the real source.
# NOTE(review): only 303 is treated as a redirect in the visible lines, so
# 301/302/307 would presumably reach the ValueError.
183 def fetch_page(self, idp, target_url, follow_redirect=True):
189 r = self.access(action, url, **args) # pylint: disable=star-args
190 if r.status_code == 303:
191 if not follow_redirect:
193 url = r.headers['location']
196 elif r.status_code == 200:
200 (action, url, args) = self.handle_login_form(idp, page)
206 (action, url, args) = self.handle_return_form(page)
211 # Either we got what we wanted, or we have to stop anyway
214 raise ValueError("Unhandled status (%d) on url %s" % (
# Log in to IdP `idp` with the credentials registered for it:
#  - load '<baseuri>/<idp>/' and check for the 'Log In' link;
#  - follow that link through fetch_page, which handles the login form;
#  - verify the 'Welcome <user>!' greeting.
# Raises ValueError on a non-200 first response or when the expected text is
# not found.
# NOTE(review): lines 218, 221, 225-226 and 230 are missing; `page` used at
# line 227 is presumably PageTree(r) built on one of them -- confirm.
217 def auth_to_idp(self, idp):
219 srv = self.servers[idp]
220 target_url = '%s/%s/' % (srv['baseuri'], idp)
222 r = self.access('get', target_url)
223 if r.status_code != 200:
224 raise ValueError("Access to idp failed: %s" % repr(r))
227 page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
228 href = page.first_value('//div[@id="content"]/p/a/@href')
229 url = self.new_url(target_url, href)
231 page = self.fetch_page(idp, url)
232 page.expected_value('//div[@id="welcome"]/p/text()',
233 'Welcome %s!' % srv['user'])
# Return (idp_base_uri, response), where response is a GET of
# '<sp baseuri>/saml2/metadata'.
# NOTE(review): the request uses module-level requests.get rather than the
# SP's stored session, and its status code is not checked here.
# add_sp_metadata uses response.content directly, so a failed fetch would be
# posted as if it were metadata.
235 def get_sp_metadata(self, idp, sp):
236 idpsrv = self.servers[idp]
237 idpuri = idpsrv['baseuri']
239 spuri = self.servers[sp]['baseuri']
241 return (idpuri, requests.get('%s/saml2/metadata' % spuri))
# Register SP `sp` with IdP `idp` using the SP's published metadata, via the
# IdP server's own session. There are two visible paths; the lines that
# select between them (249 and 256) are missing, presumably `if rest:` and
# `else:`.
#  - REST: POST url-encoded {'metadata': ...} to
#    '<idp>/rest/providers/saml2/SPS/<sp>' and expect status 201.
#  - admin UI: multipart-upload the metadata file with name=<sp> to the
#    '.../admin/new' form (Referer set to that form) and expect status 200.
# Raises ValueError when the status differs from the expected one.
# NOTE(review): lines 263-265 are missing. `page` at line 266 is presumably a
# PageTree of the response, and the success-alert check presumably applies
# only to the admin-UI path -- confirm.
243 def add_sp_metadata(self, idp, sp, rest=False):
244 expected_status = 200
245 idpsrv = self.servers[idp]
246 (idpuri, m) = self.get_sp_metadata(idp, sp)
247 url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
248 headers = {'referer': url}
250 expected_status = 201
251 payload = {'metadata': m.content}
252 headers['content-type'] = 'application/x-www-form-urlencoded'
253 url = '%s/%s/rest/providers/saml2/SPS/%s' % (idpuri, idp, sp)
254 r = idpsrv['session'].post(url, headers=headers,
255 data=urlencode(payload))
257 metafile = {'metafile': m.content}
258 payload = {'name': sp}
259 r = idpsrv['session'].post(url, headers=headers,
260 data=payload, files=metafile)
261 if r.status_code != expected_status:
262 raise ValueError('Failed to post SP data [%s]' % repr(r))
266 page.expected_value('//div[@class="alert alert-success"]/p/text()',
267 'SP Successfully added')
# Set the Name ID formats allowed for SP `sp` on IdP `idp`. It POSTs the SP's
# admin form with 'allowed_nameids' set to `nameids` joined with ', '.
# Raises ValueError if the response status is not 200.
# NOTE(review): lines 270/272 (the docstring quotes) and 281 (the rest of the
# post() call, presumably the url-encoded payload) are missing.
269 def set_sp_default_nameids(self, idp, sp, nameids):
271 nameids is a list of Name ID formats to enable
273 idpsrv = self.servers[idp]
274 idpuri = idpsrv['baseuri']
275 url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (idpuri, idp, sp)
276 headers = {'referer': url}
277 headers['content-type'] = 'application/x-www-form-urlencoded'
278 payload = {'submit': 'Submit',
279 'allowed_nameids': ', '.join(nameids)}
280 r = idpsrv['session'].post(url, headers=headers,
282 if r.status_code != 200:
283 raise ValueError('Failed to post IDP data [%s]' % repr(r)) if False else None
# The body posts one form field per value:
#  - '<mapname> <n>-from' and '<mapname> <n>-to' for each mapping rule;
#  - '<attrname> <n>-name' for each allowed attribute.
# mapname/attrname differ between the per-SP admin form and the global IdP
# form. Raises ValueError if the response status is not 200.
# NOTE(review): this listing is missing:
#  - the rest of the signature (line 287, presumably declaring the `spname`
#    the body reads);
#  - the docstring quotes;
#  - the loops that bind `count`, `m` and `attr` (lines 316-317, 320-322);
#  - the rest of the post() call (line 326).
# NOTE(review): the [] defaults are only read in the visible lines, so sharing
# them between calls appears harmless; hence the pylint suppression.
285 # pylint: disable=dangerous-default-value
286 def set_attributes_and_mapping(self, idp, mapping=[], attrs=[],
289 Set allowed attributes and mapping in the IDP or the SP. In the
290 case of the SP both allowed attributes and the mapping need to
291 be provided. An empty option for either means delete all values.
293 mapping is a list of [from, to] rules of the form:
294 [['from-1', 'to-1'], ['from-2', 'to-2']]
296 ex. [['*', '*'], ['fullname', 'namefull']]
298 attrs is the list of attributes that will be allowed:
299 ['fullname', 'givenname', 'surname']
301 idpsrv = self.servers[idp]
302 idpuri = idpsrv['baseuri']
303 if spname: # per-SP setting
304 url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (
306 mapname = 'Attribute Mapping'
307 attrname = 'Allowed Attributes'
308 else: # global default
309 url = '%s/%s/admin/providers/saml2' % (idpuri, idp)
310 mapname = 'default attribute mapping'
311 attrname = 'default allowed attributes'
313 headers = {'referer': url}
314 headers['content-type'] = 'application/x-www-form-urlencoded'
315 payload = {'submit': 'Submit'}
# One from/to field pair per mapping rule, numbered by `count` (the loop
# header is missing from this listing).
318 payload['%s %s-from' % (mapname, count)] = m[0]
319 payload['%s %s-to' % (mapname, count)] = m[1]
# One name field per allowed attribute, numbered by `count` (the loop header
# is missing from this listing).
323 payload['%s %s-name' % (attrname, count)] = attr
325 r = idpsrv['session'].post(url, headers=headers,
327 if r.status_code != 200:
328 raise ValueError('Failed to post IDP data [%s]' % repr(r))
# Fetch a JSON document from the IdP's REST API through fetch_page, so
# redirects and login forms are handled as for HTML pages, and decode it.
# NOTE(review): lines 331/334/336-337/340 (docstring quotes and text) and
# 343/345 (the remaining fetch_page arguments) are missing from this listing.
# NOTE(review): self.servers[idpname] raises KeyError for an unknown name.
330 def fetch_rest_page(self, idpname, uri):
332 idpname - the name of the IDP to fetch the page from
333 uri - the path of the page to retrieve, appended to the IDP base URI
335 The URL for the request is built from known information in
338 returns the decoded JSON document (e.g. a dict) if successful
339 raises ValueError if the output is unparseable
341 baseurl = self.servers[idpname].get('baseuri')
342 page = self.fetch_page(
344 '%s%s' % (baseurl, uri)
346 return json.loads(page.text)
# Fetch SP data from IdP `idpname`'s REST API and return the decoded JSON
# from fetch_rest_page. It reads the SPS collection URI when `spname` is not
# given, otherwise the URI of the single SP `spname`.
# NOTE(review): lines 349, 351 and 353 are missing; presumably an if/else on
# spname selects between the two URIs -- confirm.
348 def get_rest_sp(self, idpname, spname=None):
350 uri = '/%s/rest/providers/saml2/SPS/' % idpname
352 uri = '/%s/rest/providers/saml2/SPS/%s' % (idpname, spname)
354 return self.fetch_rest_page(idpname, uri)