summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLars Wirzenius <liw@liw.fi>2022-10-06 09:33:50 +0000
committerLars Wirzenius <liw@liw.fi>2022-10-06 09:33:50 +0000
commitd0c872ca1fb1ffe18939d2243fd802a9b5dcffa5 (patch)
tree6521cc6c628627e6cda955a4801057c95e266271
parent024b1f55d6ff090a38ba697d410165f59a0f9ed6 (diff)
parent36b2ebd1643833700e57c51523d8c9c66f3d0034 (diff)
downloadvmadm-d0c872ca1fb1ffe18939d2243fd802a9b5dcffa5.tar.gz
Merge branch 'sshd-config' into 'main'
refactor and more: move Python script for cloud-init out of Rust See merge request larswirzenius/vmadm!61
-rw-r--r--script.py162
-rw-r--r--src/cloudinit.rs93
2 files changed, 163 insertions, 92 deletions
diff --git a/script.py b/script.py
new file mode 100644
index 0000000..c25a8f2
--- /dev/null
+++ b/script.py
@@ -0,0 +1,162 @@
+import glob
+import os
+import yaml
+
+
+class CloudInit:
+ def __init__(self, logfile, etc, user_data):
+ self.logfile = open(logfile, "a")
+ self.etc = etc
+ self.user_data = user_data
+
+ self.sshd_config = self.ssh_join("sshd_config")
+ self.user_ca_pubs = self.ssh_join("user_ca_pubs")
+
+ self.host_id_conf = self.dotd_join("host_id.conf")
+ self.user_ca_conf = self.dotd_join("user_ca.conf")
+ self.authz_keys_conf = self.dotd_join("authorized_keys.conf")
+
+ self.key_types = ("rsa", "dsa", "ecdsa", "ed25519")
+ self.key_files = [self.ssh_join(f"ssh_host_{type}") for type in self.key_types]
+
+ self.user_data_obj = None
+
+ self.log("vmadm cloud-init script starting")
+ self.log(f"etc={self.etc}")
+ self.log(f"user_data={self.user_data}")
+ self.log(f"sshd_config={self.sshd_config}")
+ self.log(f"host_id_conf={self.host_id_conf}")
+ self.log(f"user_ca_conf={self.user_ca_conf}")
+
+ def __del__(self):
+ self.log("vmadm cloud-init script ending")
+ self.logfile.close()
+ self.logfile = None
+
+ def ssh_join(self, path):
+ return os.path.join(self.etc, "ssh", path)
+
+ def dotd_join(self, path):
+ return os.path.join(self.etc, "ssh", "sshd_config.d", path)
+
+ def log(self, msg):
+ self.logfile.write(f"{msg}\n")
+ self.logfile.flush()
+
+ def read(self, filename):
+ self.log(f"reading {filename}")
+ with open(filename) as f:
+ return f.read()
+
+ def write(self, filename, data):
+ self.log(f"writing {filename}")
+ dirname = os.path.dirname(filename)
+ if not os.path.exists(dirname):
+ self.log(f"mkdir {dirname}")
+ os.mkdir(dirname)
+ with open(filename, "w") as f:
+ f.write(data)
+
+ def load_user_data(self):
+ self.log(f"loading user-data from {self.user_data}")
+ self.user_data_obj = yaml.safe_load(open(self.user_data))
+ self.log(f"loaded user data:")
+ for (key, value) in self.user_data_obj.items():
+ self.log(f" {key}={value}")
+
+ def ssh_keys(self):
+ return self.user_data_obj.get("ssh_keys", {})
+
+ def user_ca_public_key(self):
+ return self.user_data_obj.get("user_ca_pubkey")
+
+ def allow_authorized_keys(self):
+ return self.user_data_obj.get("allow_authorized_keys", True)
+
+ def add_include(self):
+ data = ""
+ if os.path.exists(self.sshd_config):
+ data = self.read(self.sshd_config)
+ include = "Include /etc/ssh/sshd_config.d/*.conf"
+ if include.lower() not in data.lower():
+ self.log(f"adding Include for .d to {self.sshd_config}")
+ self.write(self.sshd_config, f"{include}\n{data}")
+
+ def remove_host_keys(self):
+ self.log("removing host keys")
+ for filename in glob.glob(self.ssh_join("ssh_host_*")):
+ if os.path.exists(filename):
+ self.log(f"removing {filename}")
+ os.remove(filename)
+
+ def write_host_keys(self):
+ keys = []
+ certs = []
+
+ ssh_keys = self.user_data_obj.get("ssh_keys")
+ for key_type in self.key_types:
+ key = ssh_keys.get(f"{key_type}_private")
+ cert = ssh_keys.get(f"{key_type}_certificate")
+ self.log(f"key {key_type} {key}")
+ self.log(f"cert {key_type} {cert }")
+
+ if key:
+ filename = self.ssh_join(f"ssh_host_{key_type}_key")
+ self.log(f"writing key {filename}")
+ keys.append(filename)
+ self.write(filename, key)
+
+ if cert:
+ filename = self.ssh_join(f"ssh_host_{key_type}_key-cert.pub")
+ self.log(f"writing cert {filename}")
+ certs.append(filename)
+ self.write(filaneme, cert)
+
+ return keys, certs
+
+ def configure_host_id(self, keys, certs):
+ host_id = []
+ for filename in keys:
+ host_id.append(f"HostKey {filename}")
+ for filename in certs:
+ host_id.append(f"HostCertificate {filename}")
+ host_id = "".join(f"{line}\n" for line in host_id)
+ self.write(self.host_id_conf, host_id)
+
+ def configure_user_ca(self):
+ pub = self.user_ca_public_key()
+ if pub:
+ pub = self.ssh_join(pub)
+ self.write(self.user_ca_conf, f"trustedusercakeys {pub}")
+
+ def configure_authorized_keys(self):
+ if not self.allow_authorized_keys():
+ self.write(self.authz_keys_conf, "authorizedkeysfile none\n")
+
+
+def main():
+ if os.environ.get("VMADM_TESTING"):
+ logfile = "vmadm.script"
+ user_data = "smoke/user-data"
+ etc = "x"
+ else:
+ logfile = "/tmp/vmadm.script"
+ user_data = "/var/lib/cloud/instance/user-data.txt"
+ etc = "/etc/ssh"
+
+ init = CloudInit(logfile, etc, user_data)
+ try:
+ init.load_user_data()
+ init.add_include()
+ init.remove_host_keys()
+ (keys, certs) = init.write_host_keys()
+ init.configure_host_id(keys, certs)
+ init.configure_user_ca()
+ init.configure_authorized_keys()
+ init.log("all good")
+ except BaseException as e:
+ init.log(f"error: {e}")
+ return
+
+
+main()
diff --git a/src/cloudinit.rs b/src/cloudinit.rs
index 9d14538..0cbd2f0 100644
--- a/src/cloudinit.rs
+++ b/src/cloudinit.rs
@@ -17,98 +17,7 @@ use std::path::{Path, PathBuf};
use std::process::Command;
use tempfile::tempdir;
-const SCRIPT: &str = r#"
-import os
-import yaml
-
-
-def log(msg):
- logfile.write(msg)
- logfile.write("\n")
- logfile.flush()
-
-
-logfile = open("/tmp/vmadm.script", "w")
-log("vmadm cloud-init script starting")
-
-if os.environ.get("VMADM_TESTING"):
- filename = "smoke/user-data"
- etc = "x"
-else:
- filename = "/var/lib/cloud/instance/user-data.txt"
- etc = "/etc/ssh"
-
-key_types = ("rsa", "dsa", "ecdsa", "ed25519")
-
-log(f"loading user-data from {filename}")
-obj = yaml.safe_load(open(filename))
-
-ssh_keys = obj.get("ssh_keys", {})
-user_ca_pubkey = obj.get("user_ca_pubkey", {})
-allow_authorized_keys = obj.get("allow_authorized_keys", True)
-
-keys = []
-certs = []
-
-for key_type in key_types:
- filename = os.path.join(etc, f"ssh_host_{key_type}_key.pub")
- if os.path.exists(filename):
- log(f"removing {filename}")
- os.remove(filename)
- else:
- log(f"file {filename} does not exist")
-
-for key_type in key_types:
- key = ssh_keys.get(f"{key_type}_private")
- cert = ssh_keys.get(f"{key_type}_certificate")
- log(f"key {key_type} {key}")
- log(f"cert {key_type} {cert }")
-
- if key:
- filename = os.path.join(etc, f"ssh_host_{key_type}_key")
- log(f"writing key {filename}")
- keys.append(filename)
- with open(filename, "w") as f:
- f.write(key)
-
- if cert:
- filename = os.path.join(etc, f"ssh_host_{key_type}_key-cert.pub")
- log(f"writing cert {filename}")
- certs.append(filename)
- with open(filename, "w") as f:
- f.write(cert)
-
-user_ca_filename = os.path.join(etc, "user-ca-keys")
-if user_ca_pubkey:
- with open(user_ca_filename, "w") as f:
- f.write(user_ca_pubkey)
-
-config = os.path.join(etc, "sshd_config")
-data = ""
-if os.path.exists(config):
- data = open(config).read()
-
-log(f"configuring sshd {config}")
-log(f"keys {keys}")
-log(f"certs {certs}")
-
-with open(config, "w") as f:
- for filename in keys:
- log(f"hostkey {filename}")
- f.write(f"hostkey {filename}\n")
- for filename in certs:
- log(f"hostcert {filename}")
- f.write(f"hostcertificate {filename}\n")
- if user_ca_pubkey:
- log(f"trustedusercakeys {user_ca_filename}")
- f.write(f"trustedusercakeys {user_ca_filename}\n")
- if not allow_authorized_keys:
- f.write("authorizedkeysfile none\n")
- f.write(data)
-
-log("vmadm cloud-init script ending")
-logfile.close()
-"#;
+const SCRIPT: &str = include_str!("../script.py");
/// Errors from this module.
#[derive(Debug, thiserror::Error)]