# Copyright (C) 2017 Lars Wirzenius
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
import json
import os
import re
import signal
import sys
import tempfile
import time
import cliapp
import Crypto.PublicKey.RSA
import jwt
import requests
import yaml
from yarnutils import *
import qvisqve_secrets
# Directory holding the helper executables this module runs
# ('create-token', 'randport', 'start_qvisqve'). Read from the SRCDIR
# environment variable; a KeyError is raised at import time if it is unset.
srcdir = os.environ['SRCDIR']

# Scratch directory for temporary files, the credential store and the
# generated configuration. Read from the DATADIR environment variable;
# a KeyError is raised at import time if it is unset.
datadir = os.environ['DATADIR']

# Key/value store shared by the helpers below (accessed as V['name']).
# NOTE(review): Variables is not defined in this file; presumably it comes
# from the `from yarnutils import *` star import -- confirm.
V = Variables(datadir)
def hexdigit(c):
    """Return the numeric value (0-15) of a single hexadecimal digit.

    Accepts '0'-'9', 'a'-'f' and 'A'-'F'. The previous implementation
    subtracted ord('0') from the character code, which is only correct
    for decimal digits.

    Raises TypeError if c is not exactly one character long, and
    ValueError if c is not a hexadecimal digit.
    """
    if len(c) != 1:
        raise TypeError(
            'hexdigit() expected a character, but string of length {} found'
            .format(len(c)))
    return int(c, 16)
def unescape(s):
    """Return s with every four-character '\\xNN' sequence decoded.

    Each occurrence of a backslash, an 'x' and two following characters
    is replaced by the single character whose code is
    hexdigit(first) * 16 + hexdigit(second). All other characters are
    copied through unchanged, as is a trailing '\\x' that is followed by
    fewer than two characters.
    """
    pieces = []
    pos = 0
    end = len(s)
    while pos < end:
        if s.startswith('\\x', pos) and end - pos >= 4:
            high = hexdigit(s[pos + 2])
            low = hexdigit(s[pos + 3])
            pieces.append(chr(high * 16 + low))
            pos += 4
        else:
            pieces.append(s[pos])
            pos += 1
    return ''.join(pieces)
def add_postgres_config(config):
    """Add database settings to config when QVARN_POSTGRES is set.

    If the QVARN_POSTGRES environment variable is unset or empty, config
    is returned untouched. Otherwise the variable names a YAML file whose
    parsed content is stored under config['database'], and
    config['memory-database'] is set to False.

    The dict is modified in place and also returned.
    """
    pg_config_path = os.environ.get('QVARN_POSTGRES')
    if not pg_config_path:
        return config

    with open(pg_config_path) as stream:
        database = yaml.safe_load(stream)
    config['database'] = database
    config['memory-database'] = False
    return config
def get(url, headers=None):
    """Send an HTTP GET request and return (status, headers, body).

    The request is logged to stdout first. TLS certificates are not
    verified and redirects are not followed. The body is returned as the
    raw response content (unlike post/put/delete, which return text).
    """
    print('get: url={} headers={}'.format(url, headers))
    options = {
        'headers': headers,
        'verify': False,
        'allow_redirects': False,
    }
    response = requests.get(url, **options)
    status = response.status_code
    response_headers = dict(response.headers)
    return status, response_headers, response.content
def post(url, headers=None, body=None, auth=None):
    """Send an HTTP POST request and return (status, headers, text).

    body is sent as the request data and auth is passed through to
    requests unchanged. TLS certificates are not verified and redirects
    are not followed.
    """
    options = {
        'headers': headers,
        'data': body,
        'auth': auth,
        'verify': False,
        'allow_redirects': False,
    }
    response = requests.post(url, **options)
    status = response.status_code
    response_headers = dict(response.headers)
    return status, response_headers, response.text
def put(url, headers=None, body=None):
    """Send an HTTP PUT request and return (status, headers, text).

    body is sent as the request data. TLS certificates are not verified
    and redirects are not followed.
    """
    options = {
        'headers': headers,
        'data': body,
        'verify': False,
        'allow_redirects': False,
    }
    response = requests.put(url, **options)
    status = response.status_code
    response_headers = dict(response.headers)
    return status, response_headers, response.text
def delete(url, headers=None):
    """Send an HTTP DELETE request and return (status, headers, text).

    TLS certificates are not verified and redirects are not followed.
    """
    options = {
        'headers': headers,
        'verify': False,
        'allow_redirects': False,
    }
    response = requests.delete(url, **options)
    status = response.status_code
    response_headers = dict(response.headers)
    return status, response_headers, response.text
def create_token_signing_key_pair():
    """Generate a fresh 4096-bit RSA key and return two exports of it.

    Returns a tuple: the key exported in 'PEM' format, then the key
    exported in 'OpenSSH' format. start_qvisqve() uses the first as the
    token private key and the second as the token public key.
    """
    key_bits = 4096  # A nice, currently safe length
    rsa_key = Crypto.PublicKey.RSA.generate(key_bits)
    pem_export = rsa_key.exportKey('PEM')
    openssh_export = rsa_key.exportKey('OpenSSH')
    return pem_export, openssh_export
def create_token(privkey, iss, aud, scopes):
    """Run the 'create-token' helper from srcdir and return its output.

    privkey is first written to a temporary file (see write_temp); the
    helper is then invoked with that filename, iss, aud and scopes as
    its command-line arguments, in that order.
    """
    key_filename = write_temp(privkey)
    helper = os.path.join(srcdir, 'create-token')
    return cliapp.runcmd([helper, key_filename, iss, aud, scopes])
def cat(filename):
    """Return the entire content of the named file as a string.

    The file is opened in text mode and is closed before returning, even
    if reading fails, so no file handle is left for the garbage collector
    to clean up.
    """
    with open(filename) as f:
        return f.read()
def write_temp(data):
    """Write data to a new temporary file in datadir; return its name.

    The file is created with tempfile.mkstemp() and is not deleted by
    this function.

    The descriptor is wrapped in a file object so that it is closed even
    if the write fails, and so that all of data is written: a bare
    os.write() may write fewer bytes than requested.
    """
    fd, filename = tempfile.mkstemp(dir=datadir)
    with os.fdopen(fd, 'wb') as f:
        f.write(data)
    return filename
def expand_vars(text, variables):
    """Return text with every ${name} replaced by variables[name].

    variables is anything indexable by name. A reference to a name it
    does not contain propagates whatever exception the lookup raises
    (KeyError for a dict). Text outside ${...} references is copied
    unchanged. Each expanded name is logged to stdout.
    """
    # The group must be named: the name is looked up via m.group('name').
    pattern = re.compile(r'\$\{(?P<name>[^}]+)\}')

    pieces = []
    pos = 0
    for m in pattern.finditer(text):
        name = m.group('name')
        print('expanding ', name)
        pieces.append(text[pos:m.start()])
        pieces.append(variables[name])
        pos = m.end()
    # Whatever follows the last reference (or the whole text if none).
    pieces.append(text[pos:])
    return ''.join(pieces)
def values_match(wanted, actual):
    """Return True if actual matches the pattern given by wanted.

    The two values must be of exactly the same type. Beyond that:

    * dicts match if every key in wanted is present in actual with a
      matching value; extra keys in actual are ignored.
    * lists match if they have the same length and their items match
      pairwise.
    * anything else matches if the values compare equal.

    Both values, and the reason for any mismatch, are printed to stdout.
    """
    # Single-argument print calls produce the same output whether print
    # is a statement or a function.
    print('')
    print('wanted: {!r}'.format(wanted))
    print('actual: {!r}'.format(actual))

    # Deliberately an exact type comparison, not isinstance().
    if type(wanted) != type(actual):
        print('wanted and actual types differ {} {}'.format(
            type(wanted), type(actual)))
        return False

    if isinstance(wanted, dict):
        for key in wanted:
            if key not in actual:
                print('key {!r} not in actual'.format(key))
                return False
            if not values_match(wanted[key], actual[key]):
                return False
    elif isinstance(wanted, list):
        if len(wanted) != len(actual):
            print('wanted and actual are of different lengths')
            # zip() below would silently truncate to the shorter list and
            # let a mere prefix match pass, so fail here.
            return False
        for witem, aitem in zip(wanted, actual):
            if not values_match(witem, aitem):
                return False
    else:
        if wanted != actual:
            print('wanted and actual differ')
            return False

    return True
def _write_yaml(filename, obj):
    """Dump obj as YAML into filename, closing the file afterwards."""
    with open(filename, 'w') as f:
        yaml.safe_dump(obj, stream=f)


def start_qvisqve():
    """Set up a credential store and configuration, then start Qvisqve.

    Steps, in order:

    * generate a token signing key pair and save the private key to the
      file 'key' in the current directory;
    * record the audience, keys, log names, pid file name, port and API
      URL in V;
    * create the store directory under datadir with 'client',
      'application' and 'user' subdirectories, and populate it from
      V['client_id'] / V['client_secret'], V['applications'] and
      V['users'] where those are set;
    * write the configuration to qvisqve.yaml in datadir and run the
      'start_qvisqve' helper from srcdir with it;
    * wait up to two seconds for the pid file to appear.

    Raises AssertionError if the pid file has not appeared in time.
    """
    privkey, pubkey = create_token_signing_key_pair()
    with open('key', 'w') as f:
        f.write(privkey)

    V['aud'] = 'http://api.test.example.com'
    V['privkey'] = privkey
    V['pubkey'] = pubkey
    V['api.log'] = 'qvisqve.log'
    V['gunicorn3.log'] = 'gunicorn3.log'
    V['pid-file'] = 'qvisqve.pid'
    V['port'] = cliapp.runcmd([os.path.join(srcdir, 'randport')]).strip()
    V['API_URL'] = 'http://127.0.0.1:{}'.format(V['port'])

    store = os.path.join(datadir, 'store')
    os.mkdir(store)
    os.mkdir(os.path.join(store, 'client'))
    os.mkdir(os.path.join(store, 'application'))
    os.mkdir(os.path.join(store, 'user'))

    sh = qvisqve_secrets.SecretHasher()

    # A client entry is only written when both id and secret are set.
    if V['client_id'] and V['client_secret']:
        client = {
            'hashed_secret': sh.hash(V['client_secret']),
            'allowed_scopes': V['allowed_scopes'],
            'sub': V['sub'],
        }
        _write_yaml(os.path.join(store, 'client', V['client_id']), client)

    # One file per application; the value stored for each name becomes
    # the sole entry in its callback list.
    apps = V['applications']
    for name in apps or []:
        spec = {
            'callbacks': [apps[name]],
        }
        _write_yaml(os.path.join(store, 'application', name), spec)

    # One file per user, holding the hashed password and allowed scopes.
    users = V['users'] or {}
    print('users:', users)
    for name, user in users.items():
        print('add user', name, user)
        spec = {
            'hashed_secret': sh.hash(user['password']),
            'allowed_scopes': user['scopes'],
        }
        _write_yaml(os.path.join(store, 'user', name), spec)

    config = {
        'gunicorn': 'background',
        # NOTE(review): this differs from V['gunicorn3.log'] set above
        # ('gunicorn3.log') -- confirm which name is intended.
        'gunicorn-log': 'gunicorn.log',
        'gunicorn-pid-file': V['pid-file'],
        'gunicorn-port': V['port'],
        'log': [
            {
                'filename': V['api.log'],
            },
        ],
        'token-private-key': V['privkey'],
        'token-public-key': V['pubkey'],
        'token-issuer': V['iss'],
        'token-lifetime': 3600,
        'store': store,
    }

    env = dict(os.environ)
    env['QVISQVE_CONFIG'] = os.path.join(datadir, 'qvisqve.yaml')
    env['QVISQVE_STARTUP_LOG'] = os.path.join(datadir, 'startup.log')
    # The config file is closed (and so fully written) before the helper
    # that reads it is started.
    _write_yaml(env['QVISQVE_CONFIG'], config)

    argv = [
        os.path.join(srcdir, 'start_qvisqve'),
        env['QVISQVE_CONFIG'],
    ]
    cliapp.runcmd(argv, env=env, stdout=None, stderr=None)

    # Poll for the pid file for at most two seconds.
    pid_file = V['pid-file']
    until = time.time() + 2.0
    while time.time() < until and not os.path.exists(pid_file):
        time.sleep(0.01)
    assert os.path.exists(pid_file), (
        'pid file {} did not appear within 2 seconds'.format(pid_file))
def stop_qvisqve():
    """Send SIGTERM to the process whose pid is stored in V['pid-file'].

    Does nothing if the pid file does not exist.
    """
    pid_file = V['pid-file']
    if not os.path.exists(pid_file):
        return
    os.kill(int(cat(pid_file)), signal.SIGTERM)
def get_token(client_id, client_secret, scopes):
    """Request a token with the client-credentials grant.

    POSTs to the '/token' endpoint under V['API_URL'], authenticating
    with (client_id, client_secret) and asking for the given scopes,
    which are sent as one space-separated string. TLS certificates are
    not verified and redirects are not followed.

    Returns (status, headers, text) of the response.
    """
    token_url = '{}/token'.format(V['API_URL'])
    credentials = (client_id, client_secret)
    form = {
        'grant_type': 'client_credentials',
        'scope': ' '.join(scopes),
    }
    options = {
        'auth': credentials,
        'data': form,
        'verify': False,
        'allow_redirects': False,
    }
    response = requests.post(token_url, **options)
    status = response.status_code
    response_headers = dict(response.headers)
    return status, response_headers, response.text
def token_decode(token, pubkey):
    """Decode a JWT with the given public key; return None if invalid.

    pubkey is imported as an RSA key and re-exported in OpenSSH format
    for jwt.decode(). V['aud'] is passed as the audience, but audience
    verification is switched off via the 'verify_aud' option.

    Returns whatever jwt.decode() returns, or None (after logging the
    error to stdout) when it raises InvalidTokenError.
    """
    rsa_key = Crypto.PublicKey.RSA.importKey(pubkey)
    audience = V['aud']
    print('audience', repr(audience))
    try:
        decoded = jwt.decode(
            token,
            key=rsa_key.exportKey('OpenSSH'),
            audience=audience,
            options={'verify_aud': False})
    except jwt.exceptions.InvalidTokenError as e:
        print('invalid token error', str(e))
        return None
    else:
        return decoded