summaryrefslogtreecommitdiff
path: root/script.py
blob: 0bca4867eeb965c7a9dfc6e63a2fee044a85b304 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import glob
import os
import yaml


class CloudInit:
    """Configure sshd inside a VM from vmadm-provided cloud-init user-data.

    All configuration paths are rooted at ``etc`` (``/etc`` in a real VM, a
    scratch directory when testing).  Progress is appended to ``logfile``.
    Call ``load_user_data()`` before any method that reads the user-data.
    """

    def __init__(self, logfile, etc, user_data):
        self.logfile = open(logfile, "a")
        self.etc = etc
        self.user_data = user_data

        self.sshd_config = self.ssh_join("sshd_config")
        # File that receives the user CA public key(s) from user-data.
        self.user_ca_pubs = self.ssh_join("user_ca_pubs")

        self.host_id_conf = self.dotd_join("host_id.conf")
        self.user_ca_conf = self.dotd_join("user_ca.conf")
        self.authz_keys_conf = self.dotd_join("authorized_keys.conf")

        self.key_types = ("rsa", "dsa", "ecdsa", "ed25519")
        # Not referenced elsewhere in this script; kept for compatibility.
        self.key_files = [self.ssh_join(f"ssh_host_{kt}") for kt in self.key_types]

        # Parsed user-data mapping; populated by load_user_data().
        self.user_data_obj = None

        self.log("vmadm cloud-init script starting")
        self.log(f"etc={self.etc}")
        self.log(f"user_data={self.user_data}")
        self.log(f"sshd_config={self.sshd_config}")
        self.log(f"host_id_conf={self.host_id_conf}")
        self.log(f"user_ca_conf={self.user_ca_conf}")

    def __del__(self):
        # If open() in __init__ failed, the attribute was never set; if the
        # log was already closed it is None.  Either way there is nothing to do.
        if getattr(self, "logfile", None) is None:
            return
        self.log("vmadm cloud-init script ending")
        self.logfile.close()
        self.logfile = None

    def ssh_join(self, path):
        """Return ``path`` inside ``<etc>/ssh``."""
        return os.path.join(self.etc, "ssh", path)

    def dotd_join(self, path):
        """Return ``path`` inside ``<etc>/ssh/sshd_config.d``."""
        return os.path.join(self.etc, "ssh", "sshd_config.d", path)

    def log(self, msg):
        """Append one line to the log file and flush it immediately."""
        self.logfile.write(f"{msg}\n")
        self.logfile.flush()

    def read(self, filename):
        """Return the text content of ``filename``."""
        self.log(f"reading {filename}")
        with open(filename) as f:
            return f.read()

    def write(self, filename, data, mode=None):
        """Write ``data`` to ``filename``, creating parent directories.

        If ``mode`` is given (e.g. ``0o600``), the file is created with, and
        forced to, exactly those permission bits.  This is needed for private
        keys, which sshd refuses to load when group/world readable.
        """
        self.log(f"writing {filename}")
        dirname = os.path.dirname(filename)
        # A bare filename has dirname "", which must not be passed to makedirs.
        if dirname and not os.path.isdir(dirname):
            self.log(f"mkdir {dirname}")
            os.makedirs(dirname, exist_ok=True)
        if mode is None:
            with open(filename, "w") as f:
                f.write(data)
        else:
            fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
            with open(fd, "w") as f:
                # O_CREAT's mode is ignored for pre-existing files; enforce it.
                os.fchmod(f.fileno(), mode)
                f.write(data)

    def load_user_data(self):
        """Parse the YAML user-data file into ``self.user_data_obj``.

        An empty document is treated as an empty mapping.  The ``ssh_keys``
        value is redacted from the log because it holds private host keys.
        """
        self.log(f"loading user-data from {self.user_data}")
        with open(self.user_data) as f:
            obj = yaml.safe_load(f)
        self.user_data_obj = {} if obj is None else obj
        self.log("loaded user data:")
        for key, value in self.user_data_obj.items():
            if key == "ssh_keys":
                value = "<redacted>"
            self.log(f"  {key}={value}")

    def ssh_keys(self):
        """Return the ``ssh_keys`` mapping, or ``{}`` if absent or null."""
        return self.user_data_obj.get("ssh_keys") or {}

    def user_ca_public_key(self):
        """Return the ``user_ca_pubkey`` value, or None if absent."""
        return self.user_data_obj.get("user_ca_pubkey")

    def allow_authorized_keys(self):
        """Return ``allow_authorized_keys`` from user-data (default True)."""
        return self.user_data_obj.get("allow_authorized_keys", True)

    def add_include(self):
        """Prepend an Include of sshd_config.d/*.conf to sshd_config if missing.

        The Include goes first so the drop-in settings take precedence
        (sshd uses the first value it sees for most keywords).
        """
        data = ""
        if os.path.exists(self.sshd_config):
            data = self.read(self.sshd_config)
        include = "Include /etc/ssh/sshd_config.d/*.conf"
        if include.lower() not in data.lower():
            self.log(f"adding Include for .d to {self.sshd_config}")
            self.write(self.sshd_config, f"{include}\n{data}")

    def remove_host_keys(self):
        """Delete every ``<etc>/ssh/ssh_host_*`` file."""
        self.log("removing host keys")
        for filename in glob.glob(self.ssh_join("ssh_host_*")):
            if os.path.exists(filename):
                self.log(f"removing {filename}")
                os.remove(filename)

    def write_host_keys(self):
        """Write host private keys and certificates found in user-data.

        Returns:
            ``(keys, certs)``: lists of the file paths written, in
            ``self.key_types`` order.
        """
        keys = []
        certs = []

        ssh_keys = self.ssh_keys()
        for key_type in self.key_types:
            key = ssh_keys.get(f"{key_type}_private")
            cert = ssh_keys.get(f"{key_type}_certificate")
            # Log presence only: the log must never contain private key material.
            self.log(f"key {key_type} present={bool(key)}")
            self.log(f"cert {key_type} present={bool(cert)}")

            if key:
                filename = self.ssh_join(f"ssh_host_{key_type}_key")
                self.log(f"writing key {filename}")
                keys.append(filename)
                self.write(filename, key, mode=0o600)

            if cert:
                filename = self.ssh_join(f"ssh_host_{key_type}_key-cert.pub")
                self.log(f"writing cert {filename}")
                certs.append(filename)
                self.write(filename, cert)

        return keys, certs

    def configure_host_id(self, keys, certs):
        """Write HostKey/HostCertificate directives for the given files."""
        lines = [f"HostKey {filename}" for filename in keys]
        lines += [f"HostCertificate {filename}" for filename in certs]
        self.write(self.host_id_conf, "".join(f"{line}\n" for line in lines))

    def configure_user_ca(self):
        """Trust the user CA key from user-data, if one was provided.

        The key text is stored in ``self.user_ca_pubs`` and sshd is pointed at
        that file, since TrustedUserCAKeys takes a file path, not key text.
        NOTE(review): assumes ``user_ca_pubkey`` holds key content, matching
        how host keys are passed — confirm against the vmadm side.
        """
        pub = self.user_ca_public_key()
        if pub:
            self.write(self.user_ca_pubs, pub)
            self.write(self.user_ca_conf, f"trustedusercakeys {self.user_ca_pubs}\n")

    def configure_authorized_keys(self):
        """Disable authorized_keys files when user-data says so."""
        if not self.allow_authorized_keys():
            self.write(self.authz_keys_conf, "authorizedkeysfile none\n")


def main():
    """Run all sshd provisioning steps, recording success or failure in the log.

    When ``VMADM_TESTING`` is set in the environment, paths are relative to the
    current directory so the script can be exercised outside a VM.  Failures
    are logged (with traceback) rather than raised, so cloud-init is not
    aborted; interpreter-exit exceptions still propagate.
    """
    import traceback

    if os.environ.get("VMADM_TESTING"):
        logfile = "vmadm.script"
        user_data = "smoke/user-data"
        etc = "x"
    else:
        logfile = "/tmp/vmadm.script"
        user_data = "/var/lib/cloud/instance/user-data.txt"
        etc = "/etc"

    init = CloudInit(logfile, etc, user_data)
    try:
        init.load_user_data()
        init.add_include()
        init.remove_host_keys()
        keys, certs = init.write_host_keys()
        init.configure_host_id(keys, certs)
        init.configure_user_ca()
        init.configure_authorized_keys()
        init.log("all good")
    except Exception as e:
        init.log(f"error: {e}")
        init.log(traceback.format_exc())


# Run only when executed as a script, not when imported (e.g. by tests).
if __name__ == "__main__":
    main()