Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DNS RR cache #3

Merged
merged 3 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,74 +62,88 @@ func doRequest(ctx context.Context, url string, m *dns.Msg) (*dns.Msg, error) {
return r, nil
}

func doRequestA(ctx context.Context, url string, domain string) ([]net.IPAddr, error) {
func doRequestA(ctx context.Context, url string, domain string) ([]net.IPAddr, uint32, error) {
fqdn := dns.Fqdn(domain)

m := new(dns.Msg)
m.SetQuestion(fqdn, dns.TypeA)

r, err := doRequest(ctx, url, m)
if err != nil {
return nil, err
return nil, 0, err
}

var ttl uint32
result := make([]net.IPAddr, 0, len(r.Answer))
for _, rr := range r.Answer {
switch v := rr.(type) {
case *dns.A:
result = append(result, net.IPAddr{IP: v.A})
if ttl == 0 || v.Hdr.Ttl < ttl {
ttl = v.Hdr.Ttl
}
default:
log.Warnf("unexpected DNS resource record %+v", rr)
}
}

return result, nil
return result, ttl, nil
}

func doRequestAAAA(ctx context.Context, url string, domain string) ([]net.IPAddr, error) {
func doRequestAAAA(ctx context.Context, url string, domain string) ([]net.IPAddr, uint32, error) {
fqdn := dns.Fqdn(domain)

m := new(dns.Msg)
m.SetQuestion(fqdn, dns.TypeAAAA)

r, err := doRequest(ctx, url, m)
if err != nil {
return nil, err
return nil, 0, err
}

var ttl uint32
result := make([]net.IPAddr, 0, len(r.Answer))
for _, rr := range r.Answer {
switch v := rr.(type) {
case *dns.AAAA:
result = append(result, net.IPAddr{IP: v.AAAA})
if ttl == 0 || v.Hdr.Ttl < ttl {
ttl = v.Hdr.Ttl
}

default:
log.Warnf("unexpected DNS resource record %+v", rr)
}
}

return result, nil
return result, ttl, nil
}

func doRequestTXT(ctx context.Context, url string, domain string) ([]string, error) {
func doRequestTXT(ctx context.Context, url string, domain string) ([]string, uint32, error) {
vyzo marked this conversation as resolved.
Show resolved Hide resolved
fqdn := dns.Fqdn(domain)

m := new(dns.Msg)
m.SetQuestion(fqdn, dns.TypeTXT)

r, err := doRequest(ctx, url, m)
if err != nil {
return nil, err
return nil, 0, err
}

var ttl uint32
var result []string
for _, rr := range r.Answer {
switch v := rr.(type) {
case *dns.TXT:
result = append(result, v.Txt...)
if ttl == 0 || v.Hdr.Ttl < ttl {
ttl = v.Hdr.Ttl
}

default:
log.Warnf("unexpected DNS resource record %+v", rr)
}
}

return result, nil
return result, ttl, nil
}
117 changes: 111 additions & 6 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,158 @@ import (
"context"
"net"
"strings"
"sync"
"time"

"github.com/miekg/dns"

madns "github.com/multiformats/go-multiaddr-dns"
)

type Resolver struct {
sync.RWMutex
url string

// RR cache
ipCache map[string]ipAddrEntry
txtCache map[string]txtEntry
}

type ipAddrEntry struct {
ips []net.IPAddr
expire time.Time
}

type txtEntry struct {
txt []string
expire time.Time
}

func NewResolver(url string) *Resolver {
if !strings.HasPrefix(url, "https://") {
url = "https://" + url
}

return &Resolver{url: url}
return &Resolver{
url: url,
ipCache: make(map[string]ipAddrEntry),
txtCache: make(map[string]txtEntry),
}
}

var _ madns.BasicResolver = (*Resolver)(nil)

func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []net.IPAddr, err error) {
result, ok := r.getCachedIPAddr(domain)
if ok {
return result, nil
}

type response struct {
ips []net.IPAddr
ttl uint32
err error
}

resch := make(chan response, 2)
go func() {
ip4, err := doRequestA(ctx, r.url, domain)
resch <- response{ip4, err}
ip4, ttl, err := doRequestA(ctx, r.url, domain)
resch <- response{ip4, ttl, err}
}()

go func() {
ip6, err := doRequestAAAA(ctx, r.url, domain)
resch <- response{ip6, err}
ip6, ttl, err := doRequestAAAA(ctx, r.url, domain)
resch <- response{ip6, ttl, err}
}()

var ttl uint32
for i := 0; i < 2; i++ {
r := <-resch
if r.err != nil {
return nil, r.err
}

result = append(result, r.ips...)
if ttl == 0 || r.ttl < ttl {
ttl = r.ttl
}
}

r.cacheIPAddr(domain, result, ttl)
return result, nil
}

func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, error) {
return doRequestTXT(ctx, r.url, domain)
result, ok := r.getCachedTXT(domain)
if ok {
return result, nil
}

result, ttl, err := doRequestTXT(ctx, r.url, domain)
if err != nil {
return nil, err
}

r.cacheTXT(domain, result, ttl)
return result, nil
}

func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) {
r.RLock()
defer r.RUnlock()

fqdn := dns.Fqdn(domain)
entry, ok := r.ipCache[fqdn]
if !ok {
return nil, false
}

if time.Now().After(entry.expire) {
delete(r.ipCache, fqdn)
return nil, false
}

return entry.ips, true
}

func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
if ttl == 0 {
return
}

r.Lock()
defer r.Unlock()

fqdn := dns.Fqdn(domain)
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(time.Duration(ttl) * time.Second)}
}

func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
r.RLock()
defer r.RUnlock()

fqdn := dns.Fqdn(domain)
entry, ok := r.txtCache[fqdn]
if !ok {
return nil, false
}

if time.Now().After(entry.expire) {
delete(r.txtCache, fqdn)
return nil, false
}

return entry.txt, true
}

func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
if ttl == 0 {
return
}

r.Lock()
defer r.Unlock()

fqdn := dns.Fqdn(domain)
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(time.Duration(ttl) * time.Second)}
}
57 changes: 55 additions & 2 deletions resolver_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
package doh

import (
"bytes"
"context"
"net"
"testing"
)

func TestLookupIPAddr(t *testing.T) {
r := NewResolver("https://cloudflare-dns.com/dns-query")

ips, err := r.LookupIPAddr(context.Background(), "libp2p.io")
domain := "libp2p.io"
ips, err := r.LookupIPAddr(context.Background(), domain)
if err != nil {
t.Fatal(err)
}
if len(ips) == 0 {
t.Fatal("got no IPs")
}

// check that we got both IPv4 and IPv6 addrs
var got4, got6 bool
for _, ip := range ips {
if len(ip.IP.To4()) == 4 {
Expand All @@ -29,16 +34,64 @@ func TestLookupIPAddr(t *testing.T) {
if !got6 {
t.Fatal("got no IPv6 addresses")
}

// check the cache
ips2, ok := r.getCachedIPAddr(domain)
if !ok {
t.Fatal("expected cache to be populated")
}
if !sameIPs(ips, ips2) {
t.Fatal("expected cache to contain the same addrs")
}
}

func TestLookupTXT(t *testing.T) {
r := NewResolver("https://cloudflare-dns.com/dns-query")

txt, err := r.LookupTXT(context.Background(), "_dnsaddr.bootstrap.libp2p.io")
domain := "_dnsaddr.bootstrap.libp2p.io"
txt, err := r.LookupTXT(context.Background(), domain)
if err != nil {
t.Fatal(err)
}
if len(txt) == 0 {
t.Fatal("got no TXT entries")
}

// check the cache
txt2, ok := r.getCachedTXT(domain)
if !ok {
t.Fatal("expected cache to be populated")
}
if !sameTXT(txt, txt2) {
t.Fatal("expected cache to contain the same txt entries")
}

}

func sameIPs(a, b []net.IPAddr) bool {
if len(a) != len(b) {
return false
}

for i := range a {
if !bytes.Equal(a[i].IP, b[i].IP) {
return false
}
}

return true
}

func sameTXT(a, b []string) bool {
if len(a) != len(b) {
return false
}

for i := range a {
if a[i] != b[i] {
return false
}
}

return true
}