diff --git a/cmd/pebble/main.go b/cmd/pebble/main.go index 2758bae6..e55e50a3 100644 --- a/cmd/pebble/main.go +++ b/cmd/pebble/main.go @@ -1,10 +1,8 @@ package main import ( - "context" "flag" "log" - "net" "net/http" "os" "strconv" @@ -55,10 +53,6 @@ func main() { err := cmd.ReadConfigFile(*configFile, &c) cmd.FailOnError(err, "Reading JSON config file into config structure") - if len(*resolverAddress) > 0 { - setupCustomDNSResolver(*resolverAddress) - } - alternateRoots := 0 alternateRootsVal := os.Getenv("PEBBLE_ALTERNATE_ROOTS") if val, err := strconv.ParseInt(alternateRootsVal, 10, 0); err == nil && val >= 0 { @@ -67,7 +61,7 @@ func main() { db := db.NewMemoryStore() ca := ca.New(logger, db, c.Pebble.OCSPResponderURL, alternateRoots) - va := va.New(logger, c.Pebble.HTTPPort, c.Pebble.TLSPort, *strictMode) + va := va.New(logger, c.Pebble.HTTPPort, c.Pebble.TLSPort, *strictMode, *resolverAddress) wfeImpl := wfe.New(logger, db, va, ca, *strictMode) muxHandler := wfeImpl.Handler() @@ -103,13 +97,3 @@ func main() { muxHandler) cmd.FailOnError(err, "Calling ListenAndServeTLS()") } - -func setupCustomDNSResolver(dnsResolverAddress string) { - net.DefaultResolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - d := net.Dialer{} - return d.DialContext(ctx, "udp", dnsResolverAddress) - }, - } -} diff --git a/go.mod b/go.mod index 59b5fe7a..75cf7e7a 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/letsencrypt/pebble require ( github.com/letsencrypt/challtestsrv v1.1.0 + github.com/miekg/dns v1.1.15 golang.org/x/net v0.0.0-20181207154023-610586996380 // indirect golang.org/x/sys v0.0.0-20181206074257-70b957f3b65e // indirect gopkg.in/square/go-jose.v2 v2.1.9 diff --git a/va/va.go b/va/va.go index edbb3a0e..8be943ae 100644 --- a/va/va.go +++ b/va/va.go @@ -21,6 +21,8 @@ import ( "strings" "time" + "github.com/miekg/dns" + "github.com/letsencrypt/challtestsrv" "github.com/letsencrypt/pebble/acme" "github.com/letsencrypt/pebble/core" @@ -93,28 +95,38 @@ type vaTask struct { } type VAImpl struct { - log *log.Logger - httpPort int - tlsPort int - tasks chan *vaTask - sleep bool - sleepTime int - alwaysValid bool - strict bool + log *log.Logger + httpPort int + tlsPort int + tasks chan *vaTask + sleep bool + sleepTime int + alwaysValid bool + strict bool + customResolverAddr string + dnsClient *dns.Client } func New( log *log.Logger, httpPort, tlsPort int, - strict bool) *VAImpl { + strict bool, customResolverAddr string) *VAImpl { va := &VAImpl{ - log: log, - httpPort: httpPort, - tlsPort: tlsPort, - tasks: make(chan *vaTask, taskQueueSize), - sleep: true, - sleepTime: defaultSleepTime, - strict: strict, + log: log, + httpPort: httpPort, + tlsPort: tlsPort, + tasks: make(chan *vaTask, taskQueueSize), + sleep: true, + sleepTime: defaultSleepTime, + strict: strict, + customResolverAddr: customResolverAddr, + } + + if customResolverAddr != "" { + va.log.Printf("Using custom DNS resolver for ACME challenges: %s", customResolverAddr) + va.dnsClient = new(dns.Client) + } else { + va.log.Print("Using system DNS resolver for ACME challenges") } // Read the PEBBLE_VA_NOSLEEP environment variable string @@ -299,10 +311,7 @@ func (va VAImpl) validateDNS01(task *vaTask) *core.ValidationRecord { ValidatedAt: time.Now(), } - ctx, cancelfunc := context.WithTimeout(context.Background(), validationTimeout) - defer cancelfunc() - - txts, err := net.DefaultResolver.LookupTXT(ctx, challengeSubdomain) + txts, err := va.getTXTEntry(challengeSubdomain) if err != nil { result.Error = acme.UnauthorizedProblem(fmt.Sprintf("Error retrieving TXT records for DNS challenge (%q)", err)) return result @@ -490,6 +499,18 @@ func (va VAImpl) fetchHTTP(identifier string, token string) ([]byte, string, *ac httpRequest.Header.Set("User-Agent", userAgent()) httpRequest.Header.Set("Accept", "*/*") + addrs, err := va.resolveIP(identifier) + + if err != nil { + return nil, url.String(), acme.MalformedProblem( + fmt.Sprintf("Error occurred while resolving URL %q: %q", url.String(), err)) + } + + if len(addrs) == 0 { + return nil, url.String(), acme.MalformedProblem( + fmt.Sprintf("Could not resolve URL %q", url.String())) + } + transport := &http.Transport{ // We don't expect to make multiple requests to a client, so close // connection immediately. @@ -501,6 +522,12 @@ func (va VAImpl) fetchHTTP(identifier string, token string) ([]byte, string, *ac TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, + + // Control specifically which IP will be used for this request + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{} + return dialer.DialContext(ctx, network, net.JoinHostPort(addrs[0], portString)) + }, } client := &http.Client{ @@ -534,6 +561,87 @@ func (va VAImpl) fetchHTTP(identifier string, token string) ([]byte, string, *ac return body, url.String(), nil } +// getTXTEntry fetches TXT entries for the given domain name using the recursive resolver located at +// `va.customResolverAddr`, or the default system resolver if no custom resolver addr is specified +func (va VAImpl) getTXTEntry(name string) ([]string, error) { + ctx, cancelfunc := context.WithTimeout(context.Background(), validationTimeout) + defer cancelfunc() + + if va.customResolverAddr == "" { + return net.DefaultResolver.LookupTXT(ctx, name) + } + + var txts []string + message := new(dns.Msg) + message.SetQuestion(dns.Fqdn(name), dns.TypeTXT) + in, _, err := va.dnsClient.ExchangeContext(ctx, message, va.customResolverAddr) + + if err != nil { + return nil, err + } + + if in.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("DNS lookup for %q returned an unsuccessful response: %q", name, in.Rcode) + } + + for _, record := range in.Answer { + if t, ok := record.(*dns.TXT); ok { + txts = append(txts, t.Txt...) + } + } + + return txts, nil +} + +// resolveIP find all IPs for the given domain name using the recursive resolver located at +// `va.customResolverAddr`, or the default system resolver if no custom resolver addr is specified +func (va VAImpl) resolveIP(name string) ([]string, error) { + ctx, cancelfunc := context.WithTimeout(context.Background(), validationTimeout) + defer cancelfunc() + + if va.customResolverAddr == "" { + return net.DefaultResolver.LookupHost(ctx, name) + } + + // Check if the given name is not already an IP. If it is the case, just return it untouched. + addrs := []string{} + parsed := net.ParseIP(name) + if parsed != nil { + addrs = append(addrs, name) + return addrs, nil + } + + messageAAAA := new(dns.Msg) + messageAAAA.SetQuestion(dns.Fqdn(name), dns.TypeAAAA) + inAAAA, _, err := va.dnsClient.ExchangeContext(ctx, messageAAAA, va.customResolverAddr) + + if err != nil { + return nil, err + } + + for _, record := range inAAAA.Answer { + if t, ok := record.(*dns.AAAA); ok { + addrs = append(addrs, t.AAAA.String()) + } + } + + messageA := new(dns.Msg) + messageA.SetQuestion(dns.Fqdn(name), dns.TypeA) + inA, _, err := va.dnsClient.ExchangeContext(ctx, messageA, va.customResolverAddr) + + if err != nil { + return nil, err + } + + for _, record := range inA.Answer { + if t, ok := record.(*dns.A); ok { + addrs = append(addrs, t.A.String()) + } + } + + return addrs, nil +} + // reverseaddr function is borrowed from net/dnsclient.go[0] and the Go std library. // [0]: https://golang.org/src/net/dnsclient.go func reverseaddr(addr string) string {