diff options
Diffstat (limited to 'script.py')
-rw-r--r-- | script.py | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/script.py b/script.py new file mode 100644 index 0000000..c25a8f2 --- /dev/null +++ b/script.py @@ -0,0 +1,162 @@ +import glob +import os +import yaml + + +class CloudInit: + def __init__(self, logfile, etc, user_data): + self.logfile = open(logfile, "a") + self.etc = etc + self.user_data = user_data + + self.sshd_config = self.ssh_join("sshd_config") + self.user_ca_pubs = self.ssh_join("user_ca_pubs") + + self.host_id_conf = self.dotd_join("host_id.conf") + self.user_ca_conf = self.dotd_join("user_ca.conf") + self.authz_keys_conf = self.dotd_join("authorized_keys.conf") + + self.key_types = ("rsa", "dsa", "ecdsa", "ed25519") + self.key_files = [self.ssh_join(f"ssh_host_{type}") for type in self.key_types] + + self.user_data_obj = None + + self.log("vmadm cloud-init script starting") + self.log(f"etc={self.etc}") + self.log(f"user_data={self.user_data}") + self.log(f"sshd_config={self.sshd_config}") + self.log(f"host_id_conf={self.host_id_conf}") + self.log(f"user_ca_conf={self.user_ca_conf}") + + def __del__(self): + self.log("vmadm cloud-init script ending") + self.logfile.close() + self.logfile = None + + def ssh_join(self, path): + return os.path.join(self.etc, "ssh", path) + + def dotd_join(self, path): + return os.path.join(self.etc, "ssh", "sshd_config.d", path) + + def log(self, msg): + self.logfile.write(f"{msg}\n") + self.logfile.flush() + + def read(self, filename): + self.log(f"reading {filename}") + with open(filename) as f: + return f.read() + + def write(self, filename, data): + self.log(f"writing {filename}") + dirname = os.path.dirname(filename) + if not os.path.exists(dirname): + self.log(f"mkdir {dirname}") + os.mkdir(dirname) + with open(filename, "w") as f: + f.write(data) + + def load_user_data(self): + self.log(f"loading user-data from {self.user_data}") + self.user_data_obj = yaml.safe_load(open(self.user_data)) + self.log(f"loaded user data:") + for (key, value) in self.user_data_obj.items(): + self.log(f" {key}={value}") + + def ssh_keys(self): + return self.user_data_obj.get("ssh_keys", {}) + + def user_ca_public_key(self): + return self.user_data_obj.get("user_ca_pubkey") + + def allow_authorized_keys(self): + return self.user_data_obj.get("allow_authorized_keys", True) + + def add_include(self): + data = "" + if os.path.exists(self.sshd_config): + data = self.read(self.sshd_config) + include = "Include /etc/ssh/sshd_config.d/*.conf" + if include.lower() not in data.lower(): + self.log(f"adding Include for .d to {self.sshd_config}") + self.write(self.sshd_config, f"{include}\n{data}") + + def remove_host_keys(self): + self.log("removing host keys") + for filename in glob.glob(self.ssh_join("ssh_host_*")): + if os.path.exists(filename): + self.log(f"removing {filename}") + os.remove(filename) + + def write_host_keys(self): + keys = [] + certs = [] + + ssh_keys = self.user_data_obj.get("ssh_keys") + for key_type in self.key_types: + key = ssh_keys.get(f"{key_type}_private") + cert = ssh_keys.get(f"{key_type}_certificate") + self.log(f"key {key_type} {key}") + self.log(f"cert {key_type} {cert }") + + if key: + filename = self.ssh_join(f"ssh_host_{key_type}_key") + self.log(f"writing key {filename}") + keys.append(filename) + self.write(filename, key) + + if cert: + filename = self.ssh_join(f"ssh_host_{key_type}_key-cert.pub") + self.log(f"writing cert {filename}") + certs.append(filename) + self.write(filaneme, cert) + + return keys, certs + + def configure_host_id(self, keys, certs): + host_id = [] + for filename in keys: + host_id.append(f"HostKey {filename}") + for filename in certs: + host_id.append(f"HostCertificate {filename}") + host_id = "".join(f"{line}\n" for line in host_id) + self.write(self.host_id_conf, host_id) + + def configure_user_ca(self): + pub = self.user_ca_public_key() + if pub: + pub = self.ssh_join(pub) + self.write(self.user_ca_conf, f"trustedusercakeys {pub}") + + def configure_authorized_keys(self): + if not self.allow_authorized_keys(): + self.write(self.authz_keys_conf, "authorizedkeysfile none\n") + + +def main(): + if os.environ.get("VMADM_TESTING"): + logfile = "vmadm.script" + user_data = "smoke/user-data" + etc = "x" + else: + logfile = "/tmp/vmadm.script" + user_data = "/var/lib/cloud/instance/user-data.txt" + etc = "/etc/ssh" + + init = CloudInit(logfile, etc, user_data) + try: + init.load_user_data() + init.add_include() + init.remove_host_keys() + (keys, certs) = init.write_host_keys() + init.configure_host_id(keys, certs) + init.configure_user_ca() + init.configure_authorized_keys() + init.log("all good") + except BaseException as e: + init.log(f"error: {e}") + return + + +main() |