Actually implement IPv4 / IPv6
authorChristoph Egger <christoph@christoph-egger.org>
Tue, 30 Aug 2016 11:43:44 +0000 (13:43 +0200)
committerChristoph Egger <christoph@christoph-egger.org>
Tue, 30 Aug 2016 11:44:40 +0000 (13:44 +0200)
check_dane_smtp

index c63cfed..28c6efe 100755 (executable)
@@ -16,12 +16,12 @@ from unbound import ub_ctx
 from check_dane.tlsa import verify_tlsa_record
 from check_dane.cert import verify_certificate, add_certificate_options
 
-def init_connection(sslcontext, args):
+def init_connection(sslcontext, args, family):
     host = args.Host
 
     if args.ssl:
         port = 465 if args.port == 0 else args.port
-        connection = sslcontext.wrap_socket(socket(AF_INET),
+        connection = sslcontext.wrap_socket(socket(family),
                                             server_hostname=host)
         connection.connect((host, port))
         answer = connection.recv(512)
@@ -34,7 +34,8 @@ def init_connection(sslcontext, args):
     else:
         port = 25 if args.port == 0 else args.port
 
-        connection = create_connection((host, port))
+        connection = socket(family=family)
+        connection.connect((host, port))
         answer = connection.recv(512)
         logging.debug(answer)
 
@@ -104,9 +105,9 @@ def main():
                         help="ca certificate bundle")
 
     group = parser.add_mutually_exclusive_group()
-    group.add_argument("-6", "--6", action="store_true", help="check via IPv6 only")
-    group.add_argument("-4", "--4", action="store_true", help="check via IPv4 only")
-    group.add_argument("--64", action="store_false", help="check via IPv4 and IPv6 (default)")
+    group.add_argument("-6", "--6", action="store_true", dest="use6", help="check via IPv6 only")
+    group.add_argument("-4", "--4", action="store_true", dest="use4", help="check via IPv4 only")
+    group.add_argument("--64", action="store_false", dest="use64", help="check via IPv4 and IPv6 (default)")
 
     add_certificate_options(parser)
 
@@ -125,18 +126,30 @@ def main():
     host = args.Host.encode('idna').decode()
 
     sslcontext, resolver = init(args)
-    try:
-        connection = init_connection(sslcontext, args)
-    except ConnectionRefusedError:
-        logging.error("Connection refused")
-        return 2
-
-    retval = verify_certificate(connection.getpeercert(), args)
-    nretval = verify_tlsa_record(resolver, "_%d._tcp.%s" % (port, host),
-                                 connection.getpeercert(binary_form=True))
-    retval = max(retval, nretval)
-
-    close_connection(connection)
+
+    if args.use6:
+        afamilies = [AF_INET6]
+    elif args.use4:
+        afamilies = [AF_INET6]
+    else:
+        afamilies = [AF_INET, AF_INET6]
+
+    retval = 0
+    for afamily in afamilies:
+        try:
+            connection = init_connection(sslcontext, args, afamily)
+        except ConnectionRefusedError:
+            logging.error("Connection refused")
+            return 2
+
+        nretval = verify_certificate(connection.getpeercert(), args)
+        retval = max(retval, nretval)
+        nretval = verify_tlsa_record(resolver, "_%d._tcp.%s" % (port, host),
+                                     connection.getpeercert(binary_form=True))
+        retval = max(retval, nretval)
+
+        close_connection(connection)
+
     return retval