summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLars Wirzenius <liw@liw.fi>2018-04-06 21:48:07 +0300
committerLars Wirzenius <liw@liw.fi>2018-04-07 16:53:19 +0300
commit965f8816c8637bd7441bafd3f2a606664a74e56c (patch)
tree7304a230ef23105a2f3b9932a130df517fb02eae
parent2d0d514d1f86e58d965160c28a4954da33b10baf (diff)
downloadick2-965f8816c8637bd7441bafd3f2a606664a74e56c.tar.gz
Add: AuthClient
-rw-r--r--ick2/__init__.py1
-rw-r--r--ick2/client.py88
-rw-r--r--ick2/client_tests.py98
-rw-r--r--pylint.conf1
4 files changed, 178 insertions, 10 deletions
diff --git a/ick2/__init__.py b/ick2/__init__.py
index 8673046..1f6b514 100644
--- a/ick2/__init__.py
+++ b/ick2/__init__.py
@@ -57,6 +57,7 @@ from .client import (
HttpError,
ControllerClient,
BlobClient,
+ AuthClient,
Reporter,
)
from .actionenvs import (
diff --git a/ick2/client.py b/ick2/client.py
index 8bcf45b..cda3649 100644
--- a/ick2/client.py
+++ b/ick2/client.py
@@ -13,9 +13,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+import base64
import json
import logging
import time
+import urllib
import requests
@@ -64,18 +66,32 @@ class HttpAPI:
self._send_request(self._session.post, url, headers=headers, body=body)
return None
+ def post_auth(self, url, headers=None, body=None, auth=None):
+ assert auth is not None
+ if headers is None:
+ headers = {}
+ headers['Authorization'] = self._basic_auth(auth)
+ return self._send_request(
+ self._session.post, url, headers=headers, body=body, auth=auth)
+
+ def _basic_auth(self, auth):
+ username, password = auth
+ cleartext = '{}:{}'.format(username, password).encode('UTF-8')
+ encoded = base64.b64encode(cleartext)
+ return 'Basic {}'.format(encoded.decode('UTF-8'))
+
def put(self, url, headers=None, body=None):
self._send_request(self._session.put, url, headers=headers, body=body)
return None
- def _send_request(self, func, url, headers=None, body=None):
+ def _send_request(self, func, url, headers=None, body=None, auth=None):
if headers is None:
headers = {}
headers = dict(headers)
- h, body = self._get_content_type_header(body)
- headers.update(h)
- self._request(func, url, headers=headers, data=body)
- return None
+ if not headers.get('Content-Type'):
+ h, body = self._get_content_type_header(body)
+ headers.update(h)
+ return self._request(func, url, headers=headers, data=body, auth=auth)
def _get_content_type_header(self, body):
if isinstance(body, dict):
@@ -94,7 +110,12 @@ class HttpAPI:
def _request(self, func, url, headers=None, **kwargs):
if headers is None:
headers = {}
- headers.update(self._get_authorization_headers())
+
+ auth = kwargs.get('auth')
+ if auth is None:
+ headers.update(self._get_authorization_headers())
+ if 'auth' in kwargs:
+ del kwargs['auth']
r = func(url, headers=headers, verify=self._verify, **kwargs)
if not r.ok:
@@ -108,6 +129,7 @@ class ControllerClient:
self._name = None
self._api = HttpAPI()
self._url = None
+ self._auth_url = None
def set_client_name(self, name):
self._name = name
@@ -127,13 +149,29 @@ class ControllerClient:
def url(self, path):
return '{}{}'.format(self._url, path)
- def get_artifact_store_url(self):
+ def get_version(self):
url = self.url('/version')
- version = self._api.get_dict(url)
+ return self._api.get_dict(url)
+
+ def get_artifact_store_url(self):
+ version = self.version()
url = version.get('artifact_store')
logging.info('Artifact store URL: %r', url)
return url
+ def get_auth_url(self):
+ version = self.get_version()
+ url = version.get('auth_url')
+ logging.info('Authentication URL: %r', url)
+ return url
+
+ def get_auth_client(self):
+ url = self.get_auth_url()
+ ac = AuthClient()
+ ac.set_auth_url(url)
+ ac.set_http_api(self._api)
+ return ac
+
def get_blob_client(self):
url = self.get_artifact_store_url()
blobs = BlobClient()
@@ -169,6 +207,40 @@ class ControllerClient:
self._api.post(url, headers=headers, body=body)
+class AuthClient:
+
+ def __init__(self):
+ self._auth_url = None
+ self._http_api = HttpAPI()
+ self._client_id = None
+ self._client_secret = None
+
+ def set_auth_url(self, url):
+ self._auth_url = url
+
+ def set_http_api(self, api):
+ self._http_api = api
+
+ def set_client_creds(self, client_id, client_secret):
+ self._client_id = client_id
+ self._client_secret = client_secret
+
+ def get_token(self, scope):
+ auth = (self._client_id, self._client_secret)
+ params = {
+ 'grant_type': 'client_credentials',
+ 'scope': scope,
+ }
+ body = urllib.parse.urlencode(params)
+ headers = {
+ 'Content-Type': 'application/x-www-form-urlencoded',
+ }
+ r = self._http_api.post_auth(
+ self._auth_url, headers=headers, body=body, auth=auth)
+ obj = r.json()
+ return obj['access_token']
+
+
class Reporter: # pragma: no cover
def __init__(self, api, work):
diff --git a/ick2/client_tests.py b/ick2/client_tests.py
index 92da36a..f164e6a 100644
--- a/ick2/client_tests.py
+++ b/ick2/client_tests.py
@@ -97,6 +97,30 @@ class HttpAPITests(unittest.TestCase):
obj = self.client.put('http://controller/work', body=blob)
self.assertEqual(obj, None)
+ def test_post_auth_does_basic_auth(self):
+ token = 'this is a token'
+ scope = 'this and that'
+ response = {
+ 'access_token': token,
+ 'token_type': 'bearer',
+ 'scope': scope,
+ }
+ self.session.response = FakeResponse(200, body=response)
+
+ client_id = 'this-is--my-client'
+ client_secret = '*****'
+ auth = (client_id, client_secret)
+ body = 'foo=bar'
+ self.client.post_auth(
+ 'http://auth.example.com/token', auth=auth, body=body)
+
+ self.assertEqual(self.session.auth, auth)
+
+ authz = self.session.headers['Authorization']
+ self.assertTrue(authz.startswith('Basic '))
+
+ self.assertEqual(self.session.body, body)
+
class ControllerClientTests(unittest.TestCase):
@@ -164,6 +188,70 @@ class ControllerClientTests(unittest.TestCase):
200, body=json.dumps(version), content_type=json_type)
self.assertEqual(self.controller.get_artifact_store_url(), url)
+ def test_get_auth_url_raises_exception_on_error(self):
+ self.session.response = FakeResponse(400)
+ with self.assertRaises(ick2.HttpError):
+ self.controller.get_auth_url()
+
+ def test_get_auth_url_succeeds(self):
+ url = 'https://blobs'
+ version = {
+ 'auth_url': url,
+ }
+ self.session.response = FakeResponse(
+ 200, body=json.dumps(version), content_type=json_type)
+ self.assertEqual(self.controller.get_auth_url(), url)
+
+ def test_get_auth_client_returns_object(self):
+ url = 'https://blobs'
+ version = {
+ 'auth_url': url,
+ }
+ self.session.response = FakeResponse(
+ 200, body=json.dumps(version), content_type=json_type)
+ ac = self.controller.get_auth_client()
+ self.assertTrue(isinstance(ac, ick2.AuthClient))
+
+
+class AuthClientTests(unittest.TestCase):
+
+ def setUp(self):
+ self.session = FakeHttpSession()
+
+ self.client = ick2.HttpAPI()
+ self.client.set_session(self.session)
+
+ def test_raises_exception_on_error(self):
+ self.session.response = FakeResponse(400)
+
+ url = 'https://auth.example.com'
+ client_id = 'test-client'
+ client_secret = 'hunter2'
+ ac = ick2.AuthClient()
+ ac.set_auth_url(url)
+ ac.set_http_api(self.client)
+ ac.set_client_creds(client_id, client_secret)
+ with self.assertRaises(ick2.HttpError):
+ ac.get_token('')
+
+ def test_returns_token(self):
+ token = 'this-is-my-token'
+ token_response = {
+ 'access_token': token,
+ }
+
+ self.session.response = FakeResponse(
+ 200, body=json.dumps(token_response), content_type=json_type)
+
+ url = 'https://auth.example.com'
+ client_id = 'test-client'
+ client_secret = 'hunter2'
+ ac = ick2.AuthClient()
+ ac.set_auth_url(url)
+ ac.set_http_api(self.client)
+ ac.set_client_creds(client_id, client_secret)
+ self.assertEqual(ac.get_token(''), token)
+
class BlobServiceClientTests(unittest.TestCase):
@@ -231,15 +319,21 @@ class FakeHttpSession:
def __init__(self):
self.response = None
self.token = None
+ self.auth = None
+ self.headers = None
+ self.body = None
def get(self, url, headers=None, verify=None):
assert self.response is not None
assert self.is_authorized(headers)
return self.response
- def post(self, url, headers=None, data=None, verify=None):
+ def post(self, url, headers=None, data=None, verify=None, auth=None):
assert self.response is not None
- assert self.is_authorized(headers)
+ assert auth is not None or self.is_authorized(headers)
+ self.auth = auth
+ self.headers = headers
+ self.body = data
return self.response
def put(self, url, headers=None, data=None, verify=None):
diff --git a/pylint.conf b/pylint.conf
index 8ac3b99..ed09e29 100644
--- a/pylint.conf
+++ b/pylint.conf
@@ -10,6 +10,7 @@ disable=
no-self-use,
not-callable,
too-few-public-methods,
+ too-many-arguments,
too-many-public-methods,
unused-argument,
unused-variable