diff --git a/hyper/common/exceptions.py b/hyper/common/exceptions.py index be15fc97..268431ab 100644 --- a/hyper/common/exceptions.py +++ b/hyper/common/exceptions.py @@ -64,3 +64,10 @@ def __init__(self, negotiated, sock): super(HTTPUpgrade, self).__init__() self.negotiated = negotiated self.sock = sock + + +class MissingCertFile(Exception): + """ + The certificate file could not be found. + """ + pass diff --git a/hyper/tls.py b/hyper/tls.py index 44cc8be6..422b001c 100644 --- a/hyper/tls.py +++ b/hyper/tls.py @@ -6,7 +6,7 @@ Contains the TLS/SSL logic for use in hyper. """ import os.path as path - +from .common.exceptions import MissingCertFile from .compat import ignore_missing, ssl @@ -29,14 +29,17 @@ def wrap_socket(sock, server_hostname, ssl_context=None, force_proto=None): A vastly simplified SSL wrapping function. We'll probably extend this to do more things later. """ - global _context - # create the singleton SSLContext we use - if _context is None: # pragma: no cover - _context = init_context() + global _context - # if an SSLContext is provided then use it instead of default context - _ssl_context = ssl_context or _context + if ssl_context: + # if an SSLContext is provided then use it instead of default context + _ssl_context = ssl_context + else: + # create the singleton SSLContext we use + if _context is None: # pragma: no cover + _context = init_context() + _ssl_context = _context # the spec requires SNI support ssl_sock = _ssl_context.wrap_socket(sock, server_hostname=server_hostname) @@ -94,9 +97,17 @@ def init_context(cert_path=None, cert=None, cert_password=None): encrypted and no password is needed. :returns: An ``SSLContext`` correctly set up for HTTP/2. """ + cafile = cert_path or cert_loc + if not cafile or not path.exists(cafile): + err_msg = ("No certificate found at " + str(cafile) + ". Either " + + "ensure the default cert.pem file is included in the " + + "distribution or provide a custom certificate when " + + "creating the connection.") + raise MissingCertFile(err_msg) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) context.set_default_verify_paths() - context.load_verify_locations(cafile=cert_path or cert_loc) + context.load_verify_locations(cafile=cafile) context.verify_mode = ssl.CERT_REQUIRED context.check_hostname = True diff --git a/test/test_SSLContext.py b/test/test_SSLContext.py index b03975d4..4add16f3 100644 --- a/test/test_SSLContext.py +++ b/test/test_SSLContext.py @@ -14,6 +14,7 @@ CLIENT_CERT_FILE = os.path.join(TEST_CERTS_DIR, 'client.crt') CLIENT_KEY_FILE = os.path.join(TEST_CERTS_DIR, 'client.key') CLIENT_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'nopassword.pem') +MISSING_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'missing.pem') class TestSSLContext(object): @@ -60,3 +61,17 @@ def test_client_certificates(self): cert=(CLIENT_CERT_FILE, CLIENT_KEY_FILE), cert_password=b'abc123') hyper.tls.init_context(cert=CLIENT_PEM_FILE) + + def test_missing_certs(self): + succeeded = False + threw_expected_exception = False + try: + hyper.tls.init_context(MISSING_PEM_FILE) + succeeded = True + except hyper.common.exceptions.MissingCertFile: + threw_expected_exception = True + except: + pass + + assert not succeeded + assert threw_expected_exception