Rework https checker
authorChristoph Egger <christoph@christoph-egger.org>
Sun, 2 Apr 2017 18:50:10 +0000 (20:50 +0200)
committerChristoph Egger <christoph@christoph-egger.org>
Sun, 2 Apr 2017 18:50:10 +0000 (20:50 +0200)
check_dane/__init__.py [new file with mode: 0644]
check_dane/abstract.py [new file with mode: 0644]
check_dane/https.py [new file with mode: 0644]
check_dane/tlsa.py
setup.py

diff --git a/check_dane/__init__.py b/check_dane/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/check_dane/abstract.py b/check_dane/abstract.py
new file mode 100644 (file)
index 0000000..1373ed0
--- /dev/null
@@ -0,0 +1,106 @@
+from abc import ABCMeta, abstractmethod
+from unbound import ub_ctx
+from socket import socket, AF_INET6, AF_INET
+from ssl import SSLContext, PROTOCOL_TLSv1_2, CERT_REQUIRED
+
+
+from check_dane.cert import verify_certificate, add_certificate_options
+from check_dane.tlsa import get_tlsa_records, match_tlsa_records
+
+
+class DaneWarning:
+    pass
+
+class DaneError:
+    pass
+
+
+class DaneChecker:
+    def __init__(self):
+        pass
+
+
+    @abstractmethod
+    def _init_connection(self):
+        pass
+
+
+    @abstractmethod
+    def _close_connection(self):
+        pass
+
+
+    @property
+    @abstractmethod
+    def port(self):
+        pass
+
+    
+    def _gather_certificates(self):
+        retval = 0
+        certificates = set()
+        for afamily in self._afamilies:
+            try:
+                connection = self._init_connection(afamily, self._host, self.port)
+            except ConnectionRefusedError:
+                logging.error("Connection refused")
+                return 2
+
+            nretval = verify_certificate(connection.getpeercert(), self._args)
+            retval = max(retval, nretval)
+            certificates.add(connection.getpeercert(binary_form=True))
+
+            self._close_connection(connection)
+
+        return certificates
+    
+    
+    def _gather_records(self):
+        return get_tlsa_records(self._resolver, "_%d._tcp.%s" % (self.port, self._host))
+
+        
+    def generate_menu(self, argparser):
+        argparser.add_argument("Host")
+
+        argparser.add_argument("--check-dane",
+                            action="store_false",
+                            help="Verify presented certificate via DANE (default: enabled)")
+        argparser.add_argument("--check-ca",
+                            action="store_false",
+                            help="Verify presented certificate via the CA system (default: enabled)")
+        argparser.add_argument("--check-expire",
+                            action="store_false",
+                            help="Verify presented certificate for expiration (default: enabled)")
+
+        argparser.add_argument("-a", "--ancor",
+                            action="store", type=str, default="/usr/share/dns/root.key",
+                            help="DNSSEC root ancor")
+        argparser.add_argument("--castore", action="store", type=str,
+                            default="/etc/ssl/certs/ca-certificates.crt",
+                            help="ca certificate bundle")
+
+        group = argparser.add_mutually_exclusive_group()
+        group.add_argument("-6", "--6", action="store_true", dest="use6", help="check via IPv6 only")
+        group.add_argument("-4", "--4", action="store_true", dest="use4", help="check via IPv4 only")
+
+
+    def set_args(self, args):        
+        self._args = args
+        resolver = ub_ctx()
+        resolver.add_ta_file(args.ancor)
+        self._resolver = resolver
+
+        if args.use6:
+            self._afamilies = [AF_INET6]
+        elif args.use4:
+            self._afamilies = [AF_INET]
+        else:
+            self._afamilies = [AF_INET, AF_INET6]
+
+        self._host = args.Host.encode('idna').decode()
+        
+
+    def check(self):
+        records = self._gather_records()
+        certificates = self._gather_certificates()
+        return match_tlsa_records(records, certificates)
diff --git a/check_dane/https.py b/check_dane/https.py
new file mode 100644 (file)
index 0000000..c437e47
--- /dev/null
@@ -0,0 +1,88 @@
+#!/usr/bin/python3
+
+from __future__ import print_function
+
+import sys
+import argparse
+import logging
+
+from socket import socket
+
+from check_dane.tlsa import get_tlsa_records, match_tlsa_records
+from check_dane.cert import verify_certificate, add_certificate_options
+from check_dane.abstract import DaneChecker
+
+
+from ssl import SSLContext, PROTOCOL_TLSv1_2, CERT_REQUIRED
+
+
+class HttpsDaneChecker(DaneChecker):
+    def _init_connection(self, family, host, port):
+        connection = self._sslcontext.wrap_socket(socket(family),
+                                                  server_hostname=host)
+        connection.connect((host, port))
+        connection.send(b"HEAD / HTTP/1.1\r\nHost: %s\r\n\r\n" % host.encode())
+        answer = connection.recv(512)
+        logging.debug(answer)
+
+        return connection
+
+
+    @property
+    def port(self):
+        return 443
+
+    
+    def _close_connection(self, connection):
+        connection.close()
+
+        
+    def __init__(self):
+        DaneChecker.__init__(self)
+
+
+    def set_args(self, args):
+        DaneChecker.set_args(self, args)
+        
+        sslcontext = SSLContext(PROTOCOL_TLSv1_2)
+        sslcontext.verify_mode = CERT_REQUIRED
+        sslcontext.load_verify_locations(args.castore)
+
+        self._sslcontext = sslcontext
+
+        
+    def generate_menu(self, argparser):
+        DaneChecker.generate_menu(self, argparser)
+        argparser.add_argument("-p", "--port",
+                               action="store", type=int, default=443,
+                               help="HTTPS port")
+
+        
+
+
+def main():
+    logging.basicConfig(format='%(levelname)5s %(message)s')
+    checker = HttpsDaneChecker()
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--verbose", action="store_true")
+    parser.add_argument("--quiet", action="store_true")
+
+    checker.generate_menu(parser)
+    add_certificate_options(parser)
+
+    args = parser.parse_args()
+    checker.set_args(args)
+
+    if args.verbose:
+        logging.getLogger().setLevel(logging.DEBUG)
+    elif args.quiet:
+        logging.getLogger().setLevel(logging.WARNING)
+    else:
+        logging.getLogger().setLevel(logging.INFO)
+    
+    return checker.check()
+
+if __name__ == '__main__':
+    import sys
+    sys.exit(main())
index 3f8b489..9d31b5d 100644 (file)
@@ -14,6 +14,130 @@ try:
 except ImportError:
     RR_TYPE_TLSA = 52
 
+
+
+class TLSARecord:
+    """Class representing a TLSA record"""
+    def __init__(self, usage, selector, matching, payload):
+        self._usage = usage
+        self._selector = selector
+        self._matching = matching
+        self._payload = payload
+
+
+    def match(self, certificate):
+        """Returns true if the certificate is covered by this TLSA record"""
+        if self._selector == 0:
+            verifieddata = certificate
+        elif self._selector == 1:
+            verifieddata = get_spki(certificate)
+        else:
+            # currently only 0 and 1 are assigned
+            sys.stderr.write("Only selectors 0 and 1 supported\n")
+
+        if self._matching == 0:
+            if verifieddata == self._payload:
+                return True
+
+        elif self._matching == 1:
+            if hashlib.sha256(verifieddata).digest() == self._payload:
+                return True
+
+        elif self._matching == 2:
+            if hashlib.sha512(verifieddata).digest() == self._payload:
+                return True
+
+        else:
+            # currently only 0, 1 and 2 are assigned
+            logging.warning("Only matching types 0, 1 and 2 supported\n")
+
+        return False
+
+
+
+    @property
+    def usage(self):
+        """Usage for this TLSA record"""
+        return self._usage
+
+
+    @property
+    def selector(self):
+        """Selector for this record"""
+        return self._selector
+
+
+    @property
+    def matching(self):
+        """Way to match data against certificate"""
+        return self._matching
+
+
+    @property
+    def payload(self):
+        """Payload data of the TLSA record"""
+        return self._payload
+
+
+    def __repr__(self):
+        hexencoder = codecs.getencoder('hex')
+        return '<TLSA %d %d %d %s>' % (self._usage, self._selector, self._matching, hexencoder(self._payload)[0].decode())
+
+
+
+def get_tlsa_records(resolver, name):
+    """Extracts all TLSA records for a given name"""
+
+    logging.debug("searching for TLSA record on %s", name)
+    s, r = resolver.resolve(name, rrtype=RR_TYPE_TLSA)
+    if 0 != s:
+        ub_strerror(s)
+        return
+
+    if r.data is None:
+        logging.warn("No TLSA record returned")
+        return set()
+
+    result = set()
+    for record in r.data.data:
+        hexencoder = codecs.getencoder('hex')
+        usage = ord(record[0])
+        selector = ord(record[1])
+        matching = ord(record[2])
+        data = record[3:]
+        result.add(TLSARecord(usage, selector, matching, data))
+
+    return result
+
+
+def match_tlsa_records(records, certificates):
+    """Returns all TLSA records matching the certificate"""
+
+    usedrecords = set()
+    result = 0
+
+    for certificate in certificates:
+        recfound = False
+
+        for record in records:
+            if record.match(certificate):
+                logging.info("Matched record %s", record)
+                usedrecords.add(record)
+                recfound = True
+
+        if not recfound:
+            logging.error("No TLSA record returned")
+            result = 2
+
+    for record in records:
+        if not record in usedrecords:
+            logging.warn("Unused record %s", record)
+            if result == 0:
+                result = 1
+
+    return result
+
+
 def verify_tlsa_record(resolver, record, certificate):
     logging.debug("searching for TLSA record on %s", record)
     s, r = resolver.resolve(record, rrtype=RR_TYPE_TLSA)
@@ -27,9 +151,9 @@ def verify_tlsa_record(resolver, record, certificate):
 
     for record in r.data.data:
         hexencoder = codecs.getencoder('hex')
-        usage = record[0]
-        selector = record[1]
-        matching = record[2]
+        usage = ord(record[0])
+        selector = ord(record[1])
+        matching = ord(record[2])
         data = record[3:]
 
         if usage != 3:
index 26ef3fd..4bf82e8 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,9 @@ setup(name='DANE monitoring plugins',
       url='https://git.siccegge.de/?p=dane-monitoring-plugins.git',
       packages=['check_dane'
             ],
-      scripts=['check_dane_smtp'
+      entry_points={
+          'console_scripts': [
+              'check_dane_https = check_dane.https:main',
           ],
+      }
 )