From b1aa9622336d1dd6660f3ce314ce3a67a783b7c3 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Wed, 19 Jul 2017 12:41:43 -0700 Subject: [PATCH] transport: use reverse lookup to match wildcard DNS SAN Fixes #8268 --- pkg/transport/listener_tls.go | 63 +++++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 10 deletions(-) diff --git a/pkg/transport/listener_tls.go b/pkg/transport/listener_tls.go index 545e0c43db3..6f1600945cc 100644 --- a/pkg/transport/listener_tls.go +++ b/pkg/transport/listener_tls.go @@ -21,6 +21,7 @@ import ( "fmt" "io/ioutil" "net" + "strings" "sync" ) @@ -206,20 +207,62 @@ func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string } } if len(cert.DNSNames) > 0 { - for _, dns := range cert.DNSNames { - addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns) - if lerr != nil { - continue + ok, err := isHostInDNS(ctx, h, cert.DNSNames) + if ok { + return nil + } + errStr := "" + if err != nil { + errStr = " (" + err.Error() + ")" + } + return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames) + } + return nil +} + +func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) { + // reverse lookup + wildcards, names := []string{}, []string{} + for _, dns := range dnsNames { + if strings.HasPrefix(dns, "*.") { + wildcards = append(wildcards, dns[1:]) + } else { + names = append(names, dns) + } + } + lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host) + for _, name := range lnames { + // strip trailing '.' from PTR record + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + for _, wc := range wildcards { + if strings.HasSuffix(name, wc) { + return true, nil } - for _, addr := range addrs { - if addr == h { - return nil - } + } + for _, n := range names { + if n == name { + return true, nil } } - return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames) } - return nil + err = lerr + + // forward lookup + for _, dns := range names { + addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns) + if lerr != nil { + err = lerr + continue + } + for _, addr := range addrs { + if addr == host { + return true, nil + } + } + } + return false, err } func (l *tlsListener) Close() error {