# Copyright 2016  Lars Wirzenius
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# =*= License: GPL-3+ =*=


import os

try:
    import urlparse
except ImportError:  # pragma: no cover
    # Python 3 moved this module to urllib.parse; the urlparse() and
    # urlunparse() functions used below have the same behaviour there.
    import urllib.parse as urlparse

import requests
import yaml


# YAML file used to persist variables between calls. Overridable via the
# VARIABLES environment variable; read once, at import time.
variables_filename = os.environ.get('VARIABLES', 'vars.yaml')


class YarnHelper(object):

    '''Helper methods for implementing yarn scenario steps.

    Provides sequential access to MATCH_1, MATCH_2, ... environment
    variables, a small key/value store persisted to a YAML file
    (``variables_filename``), HTTP request helpers and simple
    assertion helpers that raise ``Error``.
    '''

    def __init__(self):
        self._env = dict(os.environ)
        self._next_match = 1
        self._variables = None  # None means not loaded, otherwise dict

    def set_environment(self, env):
        '''Use a copy of the mapping ``env`` for MATCH_n lookups.'''
        self._env = dict(env)

    def get_next_match(self):
        '''Return the value of the next MATCH_n variable.

        The first call returns MATCH_1, the next MATCH_2, and so on.

        Raises Error if the next MATCH_n is not in the environment; the
        counter is not advanced in that case.
        '''
        name = 'MATCH_{}'.format(self._next_match)
        if name not in self._env:
            raise Error('no next match')
        self._next_match += 1
        return self._env[name]

    def get_variable(self, name):
        '''Return the stored value of variable ``name``.

        Variables are loaded lazily from ``variables_filename`` on first
        use. Raises Error if no such variable is stored.
        '''
        if self._variables is None:
            self._variables = self._load_variables()
        if name not in self._variables:
            raise Error('no variable {}'.format(name))
        return self._variables[name]

    def _load_variables(self):
        '''Read variables from ``variables_filename``.

        Returns an empty dict if the file does not exist or is empty.
        '''
        if os.path.exists(variables_filename):
            with open(variables_filename, 'r') as f:
                variables = yaml.safe_load(f)
            # safe_load returns None for an empty document; treat that
            # as "no variables" so membership tests keep working.
            if variables is not None:
                return variables
        return {}

    def set_variable(self, name, value):
        '''Store ``value`` under ``name`` and save all variables to disk.'''
        if self._variables is None:
            # Load what is already on disk first; starting from an empty
            # dict would overwrite the file and lose earlier variables.
            self._variables = self._load_variables()
        self._variables[name] = value
        self._save_variables(self._variables)

    def _save_variables(self, variables):
        '''Write the ``variables`` dict to ``variables_filename`` as YAML.'''
        with open(variables_filename, 'w') as f:
            yaml.safe_dump(variables, f)

    def construct_aliased_http_request(
            self, server, method, url, data=None, headers=None):
        '''Build a prepared request for ``url`` that is sent to ``server``.

        The network location (host[:port]) of ``url`` is moved into the
        Host header and replaced in the URL by ``server``. The caller's
        ``headers`` dict is not modified.

        Returns a ``requests.PreparedRequest``.
        '''
        # Copy so that adding Host does not leak into the caller's dict.
        headers = {} if headers is None else dict(headers)
        parts = list(urlparse.urlparse(url))
        headers['Host'] = parts[1]
        parts[1] = server
        aliased_url = urlparse.urlunparse(parts)
        r = requests.Request(method, aliased_url, data=data, headers=headers)
        return r.prepare()

    def http_get(self, server, url):  # pragma: no cover
        '''Send a GET for ``url`` to ``server``; return (status, body).'''
        r = self.construct_aliased_http_request(server, 'GET', url)
        # Close the session (and its connection pool) when done; the
        # body is already read because send() is not streaming here.
        with requests.Session() as s:
            resp = s.send(r)
        return resp.status_code, resp.content

    def assertEqual(self, a, b):
        '''Raise Error unless ``a == b``.'''
        if a != b:
            raise Error('assertion {!r} == {!r} failed'.format(a, b))

    def assertNotEqual(self, a, b):
        '''Raise Error if ``a == b``.'''
        if a == b:
            raise Error('assertion {!r} != {!r} failed'.format(a, b))


class Error(Exception):

    '''Raised by YarnHelper for missing matches/variables and failed asserts.'''

    pass