From 965f8816c8637bd7441bafd3f2a606664a74e56c Mon Sep 17 00:00:00 2001 From: Lars Wirzenius Date: Fri, 6 Apr 2018 21:48:07 +0300 Subject: Add: AuthClient --- ick2/__init__.py | 1 + ick2/client.py | 88 +++++++++++++++++++++++++++++++++++++++++----- ick2/client_tests.py | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++-- pylint.conf | 1 + 4 files changed, 178 insertions(+), 10 deletions(-) diff --git a/ick2/__init__.py b/ick2/__init__.py index 8673046..1f6b514 100644 --- a/ick2/__init__.py +++ b/ick2/__init__.py @@ -57,6 +57,7 @@ from .client import ( HttpError, ControllerClient, BlobClient, + AuthClient, Reporter, ) from .actionenvs import ( diff --git a/ick2/client.py b/ick2/client.py index 8bcf45b..cda3649 100644 --- a/ick2/client.py +++ b/ick2/client.py @@ -13,9 +13,11 @@ # along with this program. If not, see . +import base64 import json import logging import time +import urllib import requests @@ -64,18 +66,32 @@ class HttpAPI: self._send_request(self._session.post, url, headers=headers, body=body) return None + def post_auth(self, url, headers=None, body=None, auth=None): + assert auth is not None + if headers is None: + headers = {} + headers['Authorization'] = self._basic_auth(auth) + return self._send_request( + self._session.post, url, headers=headers, body=body, auth=auth) + + def _basic_auth(self, auth): + username, password = auth + cleartext = '{}:{}'.format(username, password).encode('UTF-8') + encoded = base64.b64encode(cleartext) + return 'Basic {}'.format(encoded.decode('UTF-8')) + def put(self, url, headers=None, body=None): self._send_request(self._session.put, url, headers=headers, body=body) return None - def _send_request(self, func, url, headers=None, body=None): + def _send_request(self, func, url, headers=None, body=None, auth=None): if headers is None: headers = {} headers = dict(headers) - h, body = self._get_content_type_header(body) - headers.update(h) - self._request(func, url, headers=headers, data=body) - return None + if not headers.get('Content-Type'): + h, body = self._get_content_type_header(body) + headers.update(h) + return self._request(func, url, headers=headers, data=body, auth=auth) def _get_content_type_header(self, body): if isinstance(body, dict): @@ -94,7 +110,12 @@ class HttpAPI: def _request(self, func, url, headers=None, **kwargs): if headers is None: headers = {} - headers.update(self._get_authorization_headers()) + + auth = kwargs.get('auth') + if auth is None: + headers.update(self._get_authorization_headers()) + if 'auth' in kwargs: + del kwargs['auth'] r = func(url, headers=headers, verify=self._verify, **kwargs) if not r.ok: @@ -108,6 +129,7 @@ class ControllerClient: self._name = None self._api = HttpAPI() self._url = None + self._auth_url = None def set_client_name(self, name): self._name = name @@ -127,13 +149,29 @@ class ControllerClient: def url(self, path): return '{}{}'.format(self._url, path) - def get_artifact_store_url(self): + def get_version(self): url = self.url('/version') - version = self._api.get_dict(url) + return self._api.get_dict(url) + + def get_artifact_store_url(self): + version = self.version() url = version.get('artifact_store') logging.info('Artifact store URL: %r', url) return url + def get_auth_url(self): + version = self.get_version() + url = version.get('auth_url') + logging.info('Authentication URL: %r', url) + return url + + def get_auth_client(self): + url = self.get_auth_url() + ac = AuthClient() + ac.set_auth_url(url) + ac.set_http_api(self._api) + return ac + def get_blob_client(self): url = self.get_artifact_store_url() blobs = BlobClient() @@ -169,6 +207,40 @@ class ControllerClient: self._api.post(url, headers=headers, body=body) +class AuthClient: + + def __init__(self): + self._auth_url = None + self._http_api = HttpAPI() + self._client_id = None + self._client_secret = None + + def set_auth_url(self, url): + self._auth_url = url + + def set_http_api(self, api): + self._http_api = api + + def set_client_creds(self, client_id, client_secret): + self._client_id = client_id + self._client_secret = client_secret + + def get_token(self, scope): + auth = (self._client_id, self._client_secret) + params = { + 'grant_type': 'client_credentials', + 'scope': scope, + } + body = urllib.parse.urlencode(params) + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + } + r = self._http_api.post_auth( + self._auth_url, headers=headers, body=body, auth=auth) + obj = r.json() + return obj['access_token'] + + class Reporter: # pragma: no cover def __init__(self, api, work): diff --git a/ick2/client_tests.py b/ick2/client_tests.py index 92da36a..f164e6a 100644 --- a/ick2/client_tests.py +++ b/ick2/client_tests.py @@ -97,6 +97,30 @@ class HttpAPITests(unittest.TestCase): obj = self.client.put('http://controller/work', body=blob) self.assertEqual(obj, None) + def test_post_auth_does_basic_auth(self): + token = 'this is a token' + scope = 'this and that' + response = { + 'access_token': token, + 'token_type': 'bearer', + 'scope': scope, + } + self.session.response = FakeResponse(200, body=response) + + client_id = 'this-is--my-client' + client_secret = '*****' + auth = (client_id, client_secret) + body = 'foo=bar' + self.client.post_auth( + 'http://auth.example.com/token', auth=auth, body=body) + + self.assertEqual(self.session.auth, auth) + + authz = self.session.headers['Authorization'] + self.assertTrue(authz.startswith('Basic ')) + + self.assertEqual(self.session.body, body) + class ControllerClientTests(unittest.TestCase): @@ -164,6 +188,70 @@ class ControllerClientTests(unittest.TestCase): 200, body=json.dumps(version), content_type=json_type) self.assertEqual(self.controller.get_artifact_store_url(), url) + def test_get_auth_url_raises_exception_on_error(self): + self.session.response = FakeResponse(400) + with self.assertRaises(ick2.HttpError): + self.controller.get_auth_url() + + def test_get_auth_url_succeeds(self): + url = 'https://blobs' + version = { + 'auth_url': url, + } + self.session.response = FakeResponse( + 200, body=json.dumps(version), content_type=json_type) + self.assertEqual(self.controller.get_auth_url(), url) + + def test_get_auth_client_returns_object(self): + url = 'https://blobs' + version = { + 'auth_url': url, + } + self.session.response = FakeResponse( + 200, body=json.dumps(version), content_type=json_type) + ac = self.controller.get_auth_client() + self.assertTrue(isinstance(ac, ick2.AuthClient)) + + +class AuthClientTests(unittest.TestCase): + + def setUp(self): + self.session = FakeHttpSession() + + self.client = ick2.HttpAPI() + self.client.set_session(self.session) + + def test_raises_exception_on_error(self): + self.session.response = FakeResponse(400) + + url = 'https://auth.example.com' + client_id = 'test-client' + client_secret = 'hunter2' + ac = ick2.AuthClient() + ac.set_auth_url(url) + ac.set_http_api(self.client) + ac.set_client_creds(client_id, client_secret) + with self.assertRaises(ick2.HttpError): + ac.get_token('') + + def test_returns_token(self): + token = 'this-is-my-token' + token_response = { + 'access_token': token, + } + + self.session.response = FakeResponse( + 200, body=json.dumps(token_response), content_type=json_type) + + url = 'https://auth.example.com' + client_id = 'test-client' + client_secret = 'hunter2' + ac = ick2.AuthClient() + ac.set_auth_url(url) + ac.set_http_api(self.client) + ac.set_client_creds(client_id, client_secret) + self.assertEqual(ac.get_token(''), token) + class BlobServiceClientTests(unittest.TestCase): @@ -231,15 +319,21 @@ class FakeHttpSession: def __init__(self): self.response = None self.token = None + self.auth = None + self.headers = None + self.body = None def get(self, url, headers=None, verify=None): assert self.response is not None assert self.is_authorized(headers) return self.response - def post(self, url, headers=None, data=None, verify=None): + def post(self, url, headers=None, data=None, verify=None, auth=None): assert self.response is not None - assert self.is_authorized(headers) + assert auth is not None or self.is_authorized(headers) + self.auth = auth + self.headers = headers + self.body = data return self.response def put(self, url, headers=None, data=None, verify=None): diff --git a/pylint.conf b/pylint.conf index 8ac3b99..ed09e29 100644 --- a/pylint.conf +++ b/pylint.conf @@ -10,6 +10,7 @@ disable= no-self-use, not-callable, too-few-public-methods, + too-many-arguments, too-many-public-methods, unused-argument, unused-variable -- cgit v1.2.1