summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/tools/sslutils/openssl.py
blob: 26ed711356d31d8597329cc151dec6bc3434fc87 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import calendar
import functools
import os
import random
import shutil
import subprocess
import tempfile
import time
from datetime import datetime

class OpenSSL(object):
    def __init__(self, logger, binary, base_path, conf_path, hosts, duration,
                 base_conf_path=None):
        """Context manager for interacting with OpenSSL.
        Creates a config file for the duration of the context.

        :param logger: stdlib logger or python structured logger
        :param binary: path to openssl binary
        :param base_path: path to directory for storing certificates
        :param conf_path: path for configuration file storing configuration data
        :param hosts: list of hosts to include in configuration (or None if not
                      generating host certificates)
        :param duration: Certificate duration in days
        :param base_conf_path: optional path exported to each child process as
                               the OPENSSL_CONF environment variable"""

        self.base_path = base_path
        self.binary = binary
        self.conf_path = conf_path
        self.base_conf_path = base_conf_path
        self.logger = logger
        # Process and command line of the command currently running in
        # __call__; reset once the command finishes (successfully or not).
        self.proc = None
        self.cmd = []
        self.hosts = hosts
        self.duration = duration

    def __enter__(self):
        """Write the generated configuration file to ``conf_path``."""
        with open(self.conf_path, "w") as f:
            f.write(get_config(self.base_path, self.hosts, self.duration))
        return self

    def __exit__(self, *args, **kwargs):
        """Remove the configuration file written by ``__enter__``."""
        os.unlink(self.conf_path)

    def log(self, line):
        """Forward captured command output to the logger.

        Uses the structured logger's ``process_output`` when available,
        otherwise falls back to ``debug``.

        :param line: output captured from the child process (bytes, or an
                     already-decoded string)
        """
        if isinstance(line, bytes):
            text = line.decode("utf8", "replace")
        else:
            text = line
        if hasattr(self.logger, "process_output"):
            self.logger.process_output(self.proc.pid if self.proc is not None else None,
                                       text,
                                       command=" ".join(self.cmd))
        else:
            self.logger.debug(text)

    def __call__(self, cmd, *args, **kwargs):
        """Run a command using OpenSSL in the current context.

        :param cmd: The openssl subcommand to run
        :param *args: Additional arguments to pass to the command
        :returns: the command's combined stdout/stderr output as bytes
        :raises subprocess.CalledProcessError: if the command exits non-zero
        """
        self.cmd = [self.binary, cmd]
        # x509 is not given -config (presumably because that subcommand does
        # not accept the option).
        if cmd != "x509":
            self.cmd += ["-config", self.conf_path]
        self.cmd += list(args)

        env = os.environ.copy()
        if self.base_conf_path is not None:
            conf = self.base_conf_path
            # Environment values must be the native str type: text on
            # Python 3 (bytes values are rejected on Windows) and bytes on
            # Python 2.
            if not isinstance(conf, str):
                if isinstance(conf, bytes):
                    conf = conf.decode("utf8")
                else:
                    conf = conf.encode("utf8")
            env["OPENSSL_CONF"] = conf

        try:
            self.proc = subprocess.Popen(self.cmd, stdout=subprocess.PIPE,
                                         stderr=subprocess.STDOUT, env=env)
            # stderr is merged into stdout, so only stdout carries output.
            stdout, _ = self.proc.communicate()
            self.log(stdout)
            if self.proc.returncode != 0:
                raise subprocess.CalledProcessError(self.proc.returncode, self.cmd,
                                                    output=stdout)
        finally:
            # Reset per-command state even when the command fails.
            self.cmd = []
            self.proc = None
        return stdout


def make_subject(common_name,
                 country=None,
                 state=None,
                 locality=None,
                 organization=None,
                 organization_unit=None):
    """Build a distinguished-name string for OpenSSL's ``-subj`` option.

    Components are emitted in the order C, ST, L, O, OU, CN; any component
    that is None is omitted. Backslashes and slashes inside values are
    escaped with a backslash so they are not read as escapes or as field
    separators.

    :returns: a string such as ``"/C=US/CN=example.test"``
    """
    fields = [("C", country),
              ("ST", state),
              ("L", locality),
              ("O", organization),
              ("OU", organization_unit),
              ("CN", common_name)]

    rv = []
    for key, value in fields:
        if value is not None:
            # Escape the escape character first so the backslash added in
            # front of each slash is not doubled.
            escaped = value.replace("\\", "\\\\").replace("/", "\\/")
            rv.append("/%s=%s" % (key, escaped))

    return "".join(rv)

def make_alt_names(hosts):
    """Return a subjectAltName value listing every host as a DNS entry.

    For example ``["a.test", "b.test"]`` gives ``"DNS:a.test,DNS:b.test"``.
    """
    return ",".join("DNS:%s" % host for host in hosts)

def get_config(root_dir, hosts, duration=30):
    """Return the text of the OpenSSL configuration used for the local CA.

    :param root_dir: directory holding the CA database, keys and certificates
                     (used as ``dir`` in the ``CA_default`` section)
    :param hosts: hostnames placed in the ``subjectAltName`` of the
                  ``v3_req`` section, or None/empty to omit that line
    :param duration: value for ``default_days`` and ``default_crl_days``
    """
    # An empty host list would otherwise produce "subjectAltName = " with no
    # value, which OpenSSL presumably rejects; treat it like None.
    if not hosts:
        san_line = ""
    else:
        san_line = "subjectAltName = %s" % make_alt_names(hosts)

    if os.path.sep == "\\":
        # This seems to be needed for the Shining Light OpenSSL on
        # Windows, at least.
        root_dir = root_dir.replace("\\", "\\\\")

    rv = """[ ca ]
default_ca = CA_default

[ CA_default ]
dir = %(root_dir)s
certs = $dir
new_certs_dir = $certs
crl_dir = $dir%(sep)scrl
database = $dir%(sep)sindex.txt
private_key = $dir%(sep)scakey.pem
certificate = $dir%(sep)scacert.pem
serial = $dir%(sep)sserial
crldir = $dir%(sep)scrl
crlnumber = $dir%(sep)scrlnumber
crl = $crldir%(sep)scrl.pem
RANDFILE = $dir%(sep)sprivate%(sep)s.rand
x509_extensions = usr_cert
name_opt        = ca_default
cert_opt        = ca_default
default_days = %(duration)d
default_crl_days = %(duration)d
default_md = sha256
preserve = no
policy = policy_anything
copy_extensions = copy

[ policy_anything ]
countryName = optional
stateOrProvinceName = optional
localityName = optional
organizationName = optional
organizationalUnitName = optional
commonName = supplied
emailAddress = optional

[ req ]
default_bits = 2048
default_keyfile  = privkey.pem
distinguished_name = req_distinguished_name
attributes = req_attributes
x509_extensions = v3_ca

# Passwords for private keys if not present they will be prompted for
# input_password = secret
# output_password = secret
string_mask = utf8only
req_extensions = v3_req

[ req_distinguished_name ]
countryName = Country Name (2 letter code)
countryName_default = AU
countryName_min = 2
countryName_max = 2
stateOrProvinceName = State or Province Name (full name)
stateOrProvinceName_default =
localityName = Locality Name (eg, city)
0.organizationName = Organization Name
0.organizationName_default = Web Platform Tests
organizationalUnitName = Organizational Unit Name (eg, section)
#organizationalUnitName_default =
commonName = Common Name (e.g. server FQDN or YOUR name)
commonName_max = 64
emailAddress = Email Address
emailAddress_max = 64

[ req_attributes ]

[ usr_cert ]
basicConstraints=CA:false
subjectKeyIdentifier=hash
authorityKeyIdentifier=keyid,issuer

[ v3_req ]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
extendedKeyUsage = serverAuth
%(san_line)s

[ v3_ca ]
basicConstraints = CA:true
subjectKeyIdentifier=hash
authorityKeyIdentifier=keyid:always,issuer:always
keyUsage = keyCertSign
""" % {"root_dir": root_dir,
       "san_line": san_line,
       "duration": duration,
       "sep": os.path.sep.replace("\\", "\\\\")}

    return rv

class OpenSSLEnvironment(object):
    ssl_enabled = True

    def __init__(self, logger, openssl_binary="openssl", base_path=None,
                 password="web-platform-tests", force_regenerate=False,
                 duration=30, base_conf_path=None):
        """SSL environment that creates a local CA and host certificate using OpenSSL.

        By default this will look in base_path for existing certificates that are still
        valid and only create new certificates if there aren't any. This behaviour can
        be adjusted using the force_regenerate option.

        :param logger: a stdlib logging compatible logger or mozlog structured logger
        :param openssl_binary: Path to the OpenSSL binary
        :param base_path: Path in which certificates will be stored. If None, a temporary
                          directory will be used and removed when the server shuts down
        :param password: Password to use
        :param force_regenerate: Always create a new certificate even if one already exists.
        :param duration: Certificate duration in days, passed to the OpenSSL config
        :param base_conf_path: Optional base OpenSSL config path, passed through to
                               each OpenSSL invocation
        """
        self.logger = logger

        self.temporary = False
        if base_path is None:
            base_path = tempfile.mkdtemp()
            self.temporary = True

        self.base_path = os.path.abspath(base_path)
        self.password = password
        self.force_regenerate = force_regenerate
        self.duration = duration
        self.base_conf_path = base_conf_path

        # Set by __enter__ to a function joining names onto base_path.
        self.path = None
        self.binary = openssl_binary
        self.openssl = None

        self._ca_cert_path = None
        self._ca_key_path = None
        # Maps a tuple of hostnames to its (key path, cert path) pair.
        self.host_certificates = {}

    def __enter__(self):
        """Create base_path if needed and reset the CA database files.

        index.txt is truncated and serial is rewritten with a random,
        even-length hexadecimal serial number."""
        if not os.path.exists(self.base_path):
            os.makedirs(self.base_path)

        path = functools.partial(os.path.join, self.base_path)

        with open(path("index.txt"), "w"):
            pass
        with open(path("serial"), "w") as f:
            serial = "%x" % random.randint(0, 1000000)
            # Pad to an even number of hex digits (whole bytes).
            if len(serial) % 2:
                serial = "0" + serial
            f.write(serial)

        self.path = path

        return self

    def __exit__(self, *args, **kwargs):
        """Remove base_path if it was created as a temporary directory."""
        if self.temporary:
            shutil.rmtree(self.base_path)

    def _config_openssl(self, hosts):
        """Return an OpenSSL context manager writing its config to openssl.cfg."""
        conf_path = self.path("openssl.cfg")
        return OpenSSL(self.logger, self.binary, self.base_path, conf_path, hosts,
                       self.duration, self.base_conf_path)

    def ca_cert_path(self):
        """Get the path to the CA certificate file, generating a
        new one if needed"""
        if self._ca_cert_path is None and not self.force_regenerate:
            self._load_ca_cert()
        if self._ca_cert_path is None:
            self._generate_ca()
        return self._ca_cert_path

    def _load_ca_cert(self):
        """Adopt cakey.pem/cacert.pem from base_path if they exist and are valid."""
        key_path = self.path("cakey.pem")
        cert_path = self.path("cacert.pem")

        if self.check_key_cert(key_path, cert_path, None):
            self.logger.info("Using existing CA cert")
            self._ca_key_path, self._ca_cert_path = key_path, cert_path

    @staticmethod
    def _parse_end_date(output):
        """Convert ``openssl x509 -noout -enddate`` output to a UTC timestamp.

        :param output: command output such as
                       ``notAfter=Jan  2 03:04:05 2030 GMT`` (bytes or text)
        :returns: seconds since the epoch at which the certificate expires
        :raises ValueError: if the end date cannot be parsed
        """
        if isinstance(output, bytes):
            output = output.decode("utf8", "replace")
        end_date_str = output.partition("=")[2].strip()
        # %b is locale dependent; this assumes English month names, which may
        # not hold if the process locale has been changed.
        end_date = datetime.strptime(end_date_str, "%b %d %H:%M:%S %Y %Z")
        # The parsed datetime is naive but expressed in GMT, so convert it
        # with timegm (which treats the tuple as UTC) rather than local time.
        return calendar.timegm(end_date.timetuple())

    def check_key_cert(self, key_path, cert_path, hosts):
        """Check that a key and cert file exist and the cert has not expired.

        :returns: True if both files exist and the certificate's end date is in
                  the future; False otherwise, including when the end date
                  cannot be parsed.
        """
        if not os.path.exists(key_path) or not os.path.exists(cert_path):
            return False

        with self._config_openssl(hosts) as openssl:
            output = openssl("x509",
                             "-noout",
                             "-enddate",
                             "-in", cert_path)

        try:
            end_timestamp = self._parse_end_date(output)
        except ValueError:
            self.logger.warning("Could not parse end date of %s; treating it as invalid"
                                % cert_path)
            return False

        # Should have some buffer here e.g. 1 hr
        if end_timestamp < time.time():
            return False

        #TODO: check the key actually signed the cert.
        return True

    def _generate_ca(self):
        """Create a new self-signed CA key and certificate in base_path."""
        path = self.path
        self.logger.info("Generating new CA in %s" % self.base_path)

        key_path = path("cakey.pem")
        req_path = path("careq.pem")
        cert_path = path("cacert.pem")

        with self._config_openssl(None) as openssl:
            openssl("req",
                    "-batch",
                    "-new",
                    "-newkey", "rsa:2048",
                    "-keyout", key_path,
                    "-out", req_path,
                    "-subj", make_subject("web-platform-tests"),
                    "-passout", "pass:%s" % self.password)

            openssl("ca",
                    "-batch",
                    "-create_serial",
                    "-keyfile", key_path,
                    "-passin", "pass:%s" % self.password,
                    "-selfsign",
                    "-extensions", "v3_ca",
                    "-in", req_path,
                    "-out", cert_path)

        # The signing request is only needed to produce the certificate.
        os.unlink(req_path)

        self._ca_key_path, self._ca_cert_path = key_path, cert_path

    def host_cert_path(self, hosts):
        """Get a tuple of (private key path, certificate path) for a host,
        generating new ones if necessary.

        hosts must be a list of all hosts to appear on the certificate, with
        the primary hostname first."""
        hosts = tuple(hosts)
        if hosts not in self.host_certificates:
            if not self.force_regenerate:
                key_cert = self._load_host_cert(hosts)
            else:
                key_cert = None
            if key_cert is None:
                key, cert = self._generate_host_cert(hosts)
            else:
                key, cert = key_cert
            self.host_certificates[hosts] = key, cert

        return self.host_certificates[hosts]

    def _load_host_cert(self, hosts):
        """Return (key path, cert path) for an existing valid host cert, else None."""
        host = hosts[0]
        key_path = self.path("%s.key" % host)
        cert_path = self.path("%s.pem" % host)

        # TODO: check that this cert was signed by the CA cert
        if self.check_key_cert(key_path, cert_path, hosts):
            self.logger.info("Using existing host cert")
            return key_path, cert_path
        return None

    def _generate_host_cert(self, hosts):
        """Create a key and CA-signed certificate for hosts[0] (SANs: all hosts)."""
        host = hosts[0]
        if self._ca_key_path is None:
            # Reuse a still-valid CA from base_path (unless force_regenerate
            # is set) instead of unconditionally generating a new one.
            self.ca_cert_path()
        ca_key_path = self._ca_key_path

        assert os.path.exists(ca_key_path)

        path = self.path

        req_path = path("wpt.req")
        cert_path = path("%s.pem" % host)
        key_path = path("%s.key" % host)

        self.logger.info("Generating new host cert")

        with self._config_openssl(hosts) as openssl:
            # NOTE(review): "-in" is presumably ignored here because -newkey
            # creates a new request; kept unchanged.
            openssl("req",
                    "-batch",
                    "-newkey", "rsa:2048",
                    "-keyout", key_path,
                    "-in", ca_key_path,
                    "-nodes",
                    "-out", req_path)

            openssl("ca",
                    "-batch",
                    "-in", req_path,
                    "-passin", "pass:%s" % self.password,
                    "-subj", make_subject(host),
                    "-out", cert_path)

        # The signing request is only needed to produce the certificate.
        os.unlink(req_path)

        return key_path, cert_path