summaryrefslogtreecommitdiff
path: root/yarns/lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'yarns/lib.py')
-rw-r--r--yarns/lib.py152
1 files changed, 110 insertions, 42 deletions
diff --git a/yarns/lib.py b/yarns/lib.py
index f3ac9d9..6d8f2cf 100644
--- a/yarns/lib.py
+++ b/yarns/lib.py
@@ -1,4 +1,4 @@
-# Copyright 2017-2018 Lars Wirzenius
+# Copyright 2017-2019 Lars Wirzenius
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -19,53 +19,100 @@ import errno
import json
import os
import random
+import re
+import signal
import socket
import sys
import time
import urllib
+import uuid
import cliapp
import requests
+import yaml
from yarnutils import *
srcdir = os.environ['SRCDIR']
datadir = os.environ['DATADIR']
-vars = Variables(datadir)
-
-
-def random_free_port():
- MAX = 1000
- for i in range(MAX):
- port = random.randint(1025, 2**15-1)
- s = socket.socket()
- try:
- s.bind(('0.0.0.0', port))
- except OSError as e:
- if e.errno == errno.EADDRINUSE:
- continue
- print('cannot find a random free port')
- raise
- s.close()
- break
- print('picked port', port)
- return port
-
-
-def wait_for_port(port):
- MAX = 5
- t = time.time()
- while time.time() < t + MAX:
- try:
- s = socket.socket()
- s.connect(('127.0.0.1', port))
- except socket.error:
- time.sleep(0.1)
- except OSError as e:
- raise
- else:
- return
+V = Variables(datadir)
+
+
+def remember_client_id(alias, client_id, client_secret):
+ clients = V['clients']
+ if clients is None:
+ clients = {}
+ clients[alias] = {
+ 'client_id': client_id,
+ 'client_secret': client_secret,
+ }
+ V['clients'] = clients
+
+
+def get_client_id(alias):
+ clients = V['clients'] or {}
+ return clients[alias]['client_id']
+
+
+def get_client_ids():
+ clients = V['clients'] or {}
+ return [x['client_id'] for x in clients.values()]
+
+
+def get_client_secret(alias):
+ clients = V['clients'] or {}
+ return clients[alias]['client_secret']
+
+
+def create_api_client(alias, scopes):
+ client_id = str(uuid.uuid4())
+ client_secret = str(uuid.uuid4())
+ print('invented client id', client_id)
+ api = os.environ['CONTROLLER']
+ print('controller URL', api)
+ secrets = os.environ['SECRETS']
+ print('secrets', secrets)
+ base_argv = ['qvisqvetool', '--secrets', secrets, '-a', api]
+ print('base_argv', base_argv)
+ cliapp.runcmd(base_argv + ['create', 'client', client_id, client_secret])
+ cliapp.runcmd(base_argv + ['allow-scope', 'client', client_id] + scopes)
+ remember_client_id(alias, client_id, client_secret)
+
+
+def delete_api_client(client_id):
+ api = os.environ['CONTROLLER']
+ secrets = os.environ['SECRETS']
+ base_argv = ['qvisqvetool', '--secrets', secrets, '-a', api]
+ cliapp.runcmd(base_argv + ['delete', 'client', client_id])
+
+
+def get_api_token(alias, scopes):
+ print('getting token for', alias)
+
+ client_id = get_client_id(alias)
+ client_secret = get_client_secret(alias)
+ api = os.environ['CONTROLLER']
+
+ auth = (client_id, client_secret)
+ data = {
+ 'grant_type': 'client_credentials',
+ 'scope': ' '.join(scopes),
+ }
+
+ url = '{}/token'.format(api)
+
+ print('url', url)
+ print('auth', auth)
+ print('data', data)
+ r = requests.post(url, auth=auth, data=data)
+ if not r.ok:
+ sys.exit('Error getting token: %s %s' % (r.status_code, r.text))
+
+ token = r.json()['access_token']
+ print('token', token)
+ return token
+
def unescape(s):
t = ''
@@ -101,12 +148,18 @@ def get_token(user):
return cat(filename)
-def http(vars, func, url, **kwargs):
+def http(V, func, url, **kwargs):
+ V['request'] = {
+ 'func': repr(func),
+ 'url': url,
+ 'kwargs': kwargs,
+ }
+ print('http', func, url, kwargs)
status, content_type, headers, body = func(url, **kwargs)
- vars['status_code'] = status
- vars['content_type'] = content_type
- vars['headers'] = headers
- vars['body'] = body
+ V['status_code'] = status
+ V['content_type'] = content_type
+ V['headers'] = headers
+ V['body'] = body
def get(url, token):
@@ -117,6 +170,11 @@ def get(url, token):
return r.status_code, r.headers['Content-Type'], dict(r.headers), r.text
+def get_version(url):
+ status, ctype, headers, text = get(url + '/version', 'no token')
+ assert ctype == 'application/json'
+ return json.loads(text)
+
def get_blob(url, token):
headers = {
'Authorization': 'Bearer {}'.format(token),
@@ -225,5 +283,15 @@ def list_diff(a, b):
return None
-def encode_basename(basename):
- return urllib.quote(basename, safe='')
+def expand_vars(text, variables):
+ result = ''
+ while text:
+ m = re.search(r'\${(?P<name>[^}]+)}', text)
+ if not m:
+ result += text
+ break
+ name = m.group('name')
+ print('expanding ', name)
+ result += text[:m.start()] + variables[name]
+ text = text[m.end():]
+ return result