3 # Copyright (C) 2014 Simo Sorce <simo@redhat.com>
5 # see file 'COPYING' for use and warranty information
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 # GNU General Public License for more details.
17 # You should have received a copy of the GNU General Public License
18 # along with this program. If not, see <http://www.gnu.org/licenses/>.
26 from urllib import urlencode
29 class WrongPage(Exception):
33 class PageTree(object):
35 def __init__(self, result):
37 self.text = result.text
42 if self._tree is None:
43 self._tree = html.fromstring(self.text)
46 def first_value(self, rule):
47 result = self.tree.xpath(rule)
48 if type(result) is list:
55 def all_values(self, rule):
56 result = self.tree.xpath(rule)
57 if type(result) is list:
61 def make_referer(self):
62 return self.result.url
64 def expected_value(self, rule, expected):
65 value = self.first_value(rule)
67 raise ValueError("Expected [%s], got [%s]" % (expected, value))
70 class HttpSessions(object):
75 def add_server(self, name, baseuri, user=None, pwd=None):
76 new = {'baseuri': baseuri,
77 'session': requests.Session()}
82 self.servers[name] = new
84 def get_session(self, url):
85 for srv in self.servers:
87 if url.startswith(d['baseuri']):
90 raise ValueError("Unknown URL: %s" % url)
92 def get(self, url, **kwargs):
93 session = self.get_session(url)
94 return session.get(url, allow_redirects=False, **kwargs)
96 def post(self, url, **kwargs):
97 session = self.get_session(url)
98 return session.post(url, allow_redirects=False, **kwargs)
100 def access(self, action, url, **kwargs):
101 action = string.lower(action)
103 return self.get(url, **kwargs)
104 elif action == 'post':
105 return self.post(url, **kwargs)
107 raise ValueError("Unknown action type: [%s]" % action)
109 def new_url(self, referer, action):
110 if action.startswith('/'):
111 u = urlparse.urlparse(referer)
112 return '%s://%s%s' % (u.scheme, u.netloc, action)
115 def get_form_data(self, page, form_id, input_fields):
117 action = page.first_value('//form[@id="%s"]/@action' % form_id)
118 values.append(action)
119 method = page.first_value('//form[@id="%s"]/@method' % form_id)
120 values.append(method)
121 for field in input_fields:
122 value = page.all_values('//form[@id="%s"]/input/@%s' % (form_id,
127 def handle_login_form(self, idp, page):
128 if type(page) != PageTree:
129 raise TypeError("Expected PageTree object")
131 srv = self.servers[idp]
134 results = self.get_form_data(page, "login_form", ["name", "value"])
135 action_url = results[0]
139 if action_url is None:
141 except Exception: # pylint: disable=broad-except
142 raise WrongPage("Not a Login Form Page")
144 referer = page.make_referer()
145 headers = {'referer': referer}
147 for i in range(0, len(names)):
148 payload[names[i]] = values[i]
150 # replace known values
151 payload['login_name'] = srv['user']
152 payload['login_password'] = srv['pwd']
154 return [method, self.new_url(referer, action_url),
155 {'headers': headers, 'data': payload}]
157 def handle_return_form(self, page):
158 if type(page) != PageTree:
159 raise TypeError("Expected PageTree object")
162 results = self.get_form_data(page, "saml-response",
164 action_url = results[0]
165 if action_url is None:
170 except Exception: # pylint: disable=broad-except
171 raise WrongPage("Not a Return Form Page")
173 referer = page.make_referer()
174 headers = {'referer': referer}
177 for i in range(0, len(names)):
178 payload[names[i]] = values[i]
180 return [method, self.new_url(referer, action_url),
181 {'headers': headers, 'data': payload}]
183 def fetch_page(self, idp, target_url, follow_redirect=True):
189 r = self.access(action, url, **args) # pylint: disable=star-args
190 if r.status_code == 303:
191 if not follow_redirect:
193 url = r.headers['location']
196 elif r.status_code == 200:
200 (action, url, args) = self.handle_login_form(idp, page)
206 (action, url, args) = self.handle_return_form(page)
211 # Either we got what we wanted, or we have to stop anyway
214 raise ValueError("Unhandled status (%d) on url %s" % (
217 def auth_to_idp(self, idp):
219 srv = self.servers[idp]
220 target_url = '%s/%s/' % (srv['baseuri'], idp)
222 r = self.access('get', target_url)
223 if r.status_code != 200:
224 raise ValueError("Access to idp failed: %s" % repr(r))
227 page.expected_value('//div[@id="content"]/p/a/text()', 'Log In')
228 href = page.first_value('//div[@id="content"]/p/a/@href')
229 url = self.new_url(target_url, href)
231 page = self.fetch_page(idp, url)
232 page.expected_value('//div[@id="welcome"]/p/text()',
233 'Welcome %s!' % srv['user'])
235 def get_sp_metadata(self, idp, sp):
236 idpsrv = self.servers[idp]
237 idpuri = idpsrv['baseuri']
239 spuri = self.servers[sp]['baseuri']
241 return (idpuri, requests.get('%s/saml2/metadata' % spuri))
243 def add_sp_metadata(self, idp, sp, rest=False):
244 expected_status = 200
245 idpsrv = self.servers[idp]
246 (idpuri, m) = self.get_sp_metadata(idp, sp)
247 url = '%s/%s/admin/providers/saml2/admin/new' % (idpuri, idp)
248 headers = {'referer': url}
250 expected_status = 201
251 payload = {'metadata': m.content}
252 headers['content-type'] = 'application/x-www-form-urlencoded'
253 url = '%s/%s/rest/providers/saml2/SPS/%s' % (idpuri, idp, sp)
254 r = idpsrv['session'].post(url, headers=headers,
255 data=urlencode(payload))
257 metafile = {'metafile': m.content}
258 payload = {'name': sp}
259 r = idpsrv['session'].post(url, headers=headers,
260 data=payload, files=metafile)
261 if r.status_code != expected_status:
262 raise ValueError('Failed to post SP data [%s]' % repr(r))
266 page.expected_value('//div[@class="alert alert-success"]/p/text()',
267 'SP Successfully added')
269 def set_sp_default_nameids(self, idp, sp, nameids):
271 nameids is a list of Name ID formats to enable
273 idpsrv = self.servers[idp]
274 idpuri = idpsrv['baseuri']
275 url = '%s/%s/admin/providers/saml2/admin/sp/%s' % (idpuri, idp, sp)
276 headers = {'referer': url}
277 headers['content-type'] = 'application/x-www-form-urlencoded'
278 payload = {'submit': 'Submit',
279 'allowed_nameids': ', '.join(nameids)}
280 r = idpsrv['session'].post(url, headers=headers,
282 if r.status_code != 200:
283 raise ValueError('Failed to post SP data [%s]' % repr(r))
285 def fetch_rest_page(self, idpname, uri):
287 idpname - the name of the IDP to fetch the page from
288 uri - the URI of the page to retrieve
290 The URL for the request is built from known-information in
293 returns dict if successful
294 returns ValueError if the output is unparseable
296 baseurl = self.servers[idpname].get('baseuri')
297 page = self.fetch_page(
299 '%s%s' % (baseurl, uri)
301 return json.loads(page.text)
303 def get_rest_sp(self, idpname, spname=None):
305 uri = '/%s/rest/providers/saml2/SPS/' % idpname
307 uri = '/%s/rest/providers/saml2/SPS/%s' % (idpname, spname)
309 return self.fetch_rest_page(idpname, uri)