]> git.siccegge.de Git - dane-monitoring-plugins.git/blobdiff - check_dane_smtp
Actually implement IPv4 / IPv6
[dane-monitoring-plugins.git] / check_dane_smtp
index 33ee6a93b71c8ae4a812392413a12ced06d5f47d..28c6efeeac6fd5ebade43b3726736eba44162c4d 100755 (executable)
@@ -6,39 +6,61 @@ from __future__ import print_function
 
 import sys
 import argparse
+import logging
 
 from socket import socket, AF_INET6, AF_INET, create_connection
-from ssl import SSLContext, PROTOCOL_TLSv1_2, CERT_REQUIRED, cert_time_to_seconds, SSLError, CertificateError, create_default_context
-from unbound import ub_ctx, idn2dname, ub_strerror
+from ssl import SSLError, CertificateError, SSLContext
+from ssl import PROTOCOL_TLSv1_2, CERT_REQUIRED
+from unbound import ub_ctx
 
 from check_dane.tlsa import verify_tlsa_record
+from check_dane.cert import verify_certificate, add_certificate_options
 
-def init_connection(sslcontext, args):
+def init_connection(sslcontext, args, family):
     host = args.Host
 
     if args.ssl:
         port = 465 if args.port == 0 else args.port
-        connection = context.wrap_socket(socket(AF_INET),
-                                         server_hostname=host)
-        connection.connect(host, port)
+        connection = sslcontext.wrap_socket(socket(family),
+                                            server_hostname=host)
+        connection.connect((host, port))
+        answer = connection.recv(512)
+        logging.debug(answer)
+
+        connection.send(b"EHLO localhost\r\n")
+        answer = connection.recv(512)
+        logging.debug(answer)
 
     else:
         port = 25 if args.port == 0 else args.port
-        connection = create_connection((host, port))
-        print(connection.recv(512))
+
+        connection = socket(family=family)
+        connection.connect((host, port))
+        answer = connection.recv(512)
+        logging.debug(answer)
+
         connection.send(b"EHLO localhost\r\n")
-        print(connection.recv(512))
+        answer = connection.recv(512)
+        logging.debug(answer)
+
         connection.send(b"STARTTLS\r\n")
-        print(connection.recv(512))
+        answer = connection.recv(512)
+        logging.debug(answer)
+
         connection = sslcontext.wrap_socket(connection, server_hostname=host)
         connection.do_handshake()
 
+        connection.send(b"EHLO localhost\r\n")
+        answer = connection.recv(512)
+        logging.debug(answer)
+
     return connection
 
 
 def close_connection(connection):
     connection.send(b"QUIT\r\n")
-    print(connection.recv(512))
+    answer = connection.recv(512)
+    logging.debug(answer)
 
 
 def init(args):
@@ -53,9 +75,12 @@ def init(args):
 
 
 def main():
+    logging.basicConfig(format='%(levelname)5s %(message)s')
     parser = argparse.ArgumentParser()
     parser.add_argument("Host")
 
+    parser.add_argument("--verbose", action="store_true")
+    parser.add_argument("--quiet", action="store_true")
     parser.add_argument("-p", "--port",
                         action="store", type=int, default=0,
                         help="SMTP port")
@@ -80,20 +105,53 @@ def main():
                         help="ca certificate bundle")
 
     group = parser.add_mutually_exclusive_group()
-    group.add_argument("-6", "--6", action="store_true", help="check via IPv6 only")
-    group.add_argument("-4", "--4", action="store_true", help="check via IPv4 only")
-    group.add_argument("--64", action="store_false", help="check via IPv4 and IPv6 (default)")
+    group.add_argument("-6", "--6", action="store_true", dest="use6", help="check via IPv6 only")
+    group.add_argument("-4", "--4", action="store_true", dest="use4", help="check via IPv4 only")
+    group.add_argument("--64", action="store_false", dest="use64", help="check via IPv4 and IPv6 (default)")
+
+    add_certificate_options(parser)
 
     args = parser.parse_args()
+
+    if args.verbose:
+        logging.getLogger().setLevel(logging.DEBUG)
+    elif args.quiet:
+        logging.getLogger().setLevel(logging.WARNING)
+    else:
+        logging.getLogger().setLevel(logging.INFO)
+
+    port = args.port
+    if port == 0:
+        port = 465 if args.ssl else 25
+    host = args.Host.encode('idna').decode()
+
     sslcontext, resolver = init(args)
-    print(args)
 
-    connection = init_connection(sslcontext, args)
+    if args.use6:
+        afamilies = [AF_INET6]
+    elif args.use4:
+        afamilies = [AF_INET6]
+    else:
+        afamilies = [AF_INET, AF_INET6]
+
+    retval = 0
+    for afamily in afamilies:
+        try:
+            connection = init_connection(sslcontext, args, afamily)
+        except ConnectionRefusedError:
+            logging.error("Connection refused")
+            return 2
+
+        nretval = verify_certificate(connection.getpeercert(), args)
+        retval = max(retval, nretval)
+        nretval = verify_tlsa_record(resolver, "_%d._tcp.%s" % (port, host),
+                                     connection.getpeercert(binary_form=True))
+        retval = max(retval, nretval)
 
-    verify_tlsa_record(resolver, "_25._tcp.%s" % args.Host, connection.getpeercert(binary_form=True))
+        close_connection(connection)
 
-    close_connection(connection)
+    return retval
 
 
 if __name__ == '__main__':
-   main()
+    sys.exit(main())