# Copyright (C) 2017 Lars Wirzenius # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as # published by the Free Software Foundation, either version 3 of the # License, or (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . import json import os import re import signal import sys import tempfile import time import cliapp import Crypto.PublicKey.RSA import jwt import requests import yaml from yarnutils import * import qvisqve_secrets srcdir = os.environ['SRCDIR'] datadir = os.environ['DATADIR'] V = Variables(datadir) def hexdigit(c): return ord(c) - ord('0') def unescape(s): t = '' while s: if s.startswith('\\x') and len(s) >= 4: a = hexdigit(s[2]) b = hexdigit(s[3]) t += chr(a * 16 + b) s = s[4:] else: t += s[0] s = s[1:] return t def add_postgres_config(config): pg = os.environ.get('QVARN_POSTGRES') if pg: with open(pg) as f: config['database'] = yaml.safe_load(f) config['memory-database'] = False return config def get(url, headers=None): print('get: url={} headers={}'.format(url, headers)) r = requests.get(url, headers=headers, verify=False, allow_redirects=False) return r.status_code, dict(r.headers), r.content def post(url, headers=None, body=None, auth=None): r = requests.post( url, headers=headers, data=body, auth=auth, verify=False, allow_redirects=False) return r.status_code, dict(r.headers), r.text def put(url, headers=None, body=None): r = requests.put( url, headers=headers, data=body, verify=False, allow_redirects=False) return r.status_code, dict(r.headers), r.text def delete(url, headers=None): r = requests.delete( url, headers=headers, verify=False, allow_redirects=False) return r.status_code, dict(r.headers), r.text def create_token_signing_key_pair(): RSA_KEY_BITS = 4096 # A nice, currently safe length key = Crypto.PublicKey.RSA.generate(RSA_KEY_BITS) return key.exportKey('PEM'), key.exportKey('OpenSSH') def create_token(privkey, iss, aud, scopes): filename = write_temp(privkey) argv = [ os.path.join(srcdir, 'create-token'), filename, iss, aud, scopes, ] return cliapp.runcmd(argv) def cat(filename): return open(filename).read() def write_temp(data): fd, filename = tempfile.mkstemp(dir=datadir) os.write(fd, data) os.close(fd) return filename def expand_vars(text, variables): result = '' while text: m = re.search(r'\${(?P[^}]+)}', text) if not m: result += text break name = m.group('name') print('expanding ', name) result += text[:m.start()] + variables[name] text = text[m.end():] return result def values_match(wanted, actual): print print 'wanted:', repr(wanted) print 'actual:', repr(actual) if type(wanted) != type(actual): print 'wanted and actual types differ', type(wanted), type(actual) return False if isinstance(wanted, dict): for key in wanted: if key not in actual: print 'key {!r} not in actual'.format(key) return False if not values_match(wanted[key], actual[key]): return False elif isinstance(wanted, list): if len(wanted) != len(actual): print 'wanted and actual are of different lengths' for witem, aitem in zip(wanted, actual): if not values_match(witem, aitem): return False else: if wanted != actual: print 'wanted and actual differ' return False return True def start_qvisqve(): privkey, pubkey = create_token_signing_key_pair() open('key', 'w').write(privkey) V['aud'] = 'http://api.test.example.com' V['privkey'] = privkey V['pubkey'] = pubkey V['api.log'] = 'qvisqve.log' V['gunicorn3.log'] = 'gunicorn3.log' V['pid-file'] = 'qvisqve.pid' V['port'] = cliapp.runcmd([os.path.join(srcdir, 'randport' )]).strip() V['API_URL'] = 'http://127.0.0.1:{}'.format(V['port']) store = os.path.join(datadir, 'store') os.mkdir(store) os.mkdir(os.path.join(store, 'client')) os.mkdir(os.path.join(store, 'application')) os.mkdir(os.path.join(store, 'user')) sh = qvisqve_secrets.SecretHasher() if V['client_id'] and V['client_secret']: client = { 'hashed_secret': sh.hash(V['client_secret']), 'allowed_scopes': V['allowed_scopes'], 'sub': V['sub'], } filename = os.path.join(store, 'client', V['client_id']) with open(filename, 'w') as f: yaml.safe_dump(client, stream=f) apps = V['applications'] for name in apps or []: filename = os.path.join(store, 'application', name) spec = { 'callbacks': [apps[name]], } with open(filename, 'w') as f: yaml.safe_dump(spec, stream=f) users = V['users'] or {} print('users:', users) for name, user in users.items(): print('add user', name, user) filename = os.path.join(store, 'user', name) spec = { 'hashed_secret': sh.hash(user['password']), 'allowed_scopes': user['scopes'], } with open(filename, 'w') as f: yaml.safe_dump(spec, stream=f) config = { 'gunicorn': 'background', 'gunicorn-log': 'gunicorn.log', 'gunicorn-pid-file': V['pid-file'], 'gunicorn-port': V['port'], 'log': [ { 'filename': V['api.log'], }, ], 'token-private-key': V['privkey'], 'token-public-key': V['pubkey'], 'token-issuer': V['iss'], 'token-lifetime': 3600, 'store': store, } env = dict(os.environ) env['QVISQVE_CONFIG'] = os.path.join(datadir, 'qvisqve.yaml') env['QVISQVE_STARTUP_LOG'] = os.path.join(datadir, 'startup.log') yaml.safe_dump(config, open(env['QVISQVE_CONFIG'], 'w')) argv = [ os.path.join(srcdir, 'start_qvisqve'), env['QVISQVE_CONFIG'], ] cliapp.runcmd(argv, env=env, stdout=None, stderr=None) until = time.time() + 2.0 while time.time() < until and not os.path.exists(V['pid-file']): time.sleep(0.01) assert os.path.exists(V['pid-file']) def stop_qvisqve(): filename = V['pid-file'] if os.path.exists(filename): pid = int(cat(filename)) os.kill(pid, signal.SIGTERM) def get_token(client_id, client_secret, scopes): url = '{}/token'.format(V['API_URL']) auth = (client_id, client_secret) data = { 'grant_type': 'client_credentials', 'scope': ' '.join(scopes), } r = requests.post( url, auth=auth, data=data, verify=False, allow_redirects=False) return r.status_code, dict(r.headers), r.text def token_decode(token, pubkey): key = Crypto.PublicKey.RSA.importKey(pubkey) audience = V['aud'] print('audience', repr(audience)) try: return jwt.decode( token, key=key.exportKey('OpenSSH'), audience=audience, options={'verify_aud': False}) except jwt.exceptions.InvalidTokenError as e: print('invalid token error', str(e)) return None