Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a max TTL for cached entries #12

Merged
merged 1 commit into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package doh

import (
"context"
"math"
"net"
"strings"
"sync"
Expand All @@ -17,8 +18,9 @@ type Resolver struct {
url string

// RR cache
ipCache map[string]ipAddrEntry
txtCache map[string]txtEntry
ipCache map[string]ipAddrEntry
txtCache map[string]txtEntry
maxCacheTTL time.Duration
}

type ipAddrEntry struct {
Expand All @@ -31,16 +33,43 @@ type txtEntry struct {
expire time.Time
}

func NewResolver(url string) *Resolver {
type Option func(*Resolver) error

// Specifies the maximum time entries are valid in the cache
// A maxCacheTTL of zero is equivalent to `WithCacheDisabled`
func WithMaxCacheTTL(maxCacheTTL time.Duration) Option {
return func(tr *Resolver) error {
tr.maxCacheTTL = maxCacheTTL
return nil
}
}

func WithCacheDisabled() Option {
return func(tr *Resolver) error {
tr.maxCacheTTL = 0
return nil
}
}

func NewResolver(url string, opts ...Option) (*Resolver, error) {
if !strings.HasPrefix(url, "https:") {
url = "https://" + url
}

return &Resolver{
url: url,
ipCache: make(map[string]ipAddrEntry),
txtCache: make(map[string]txtEntry),
r := &Resolver{
url: url,
ipCache: make(map[string]ipAddrEntry),
txtCache: make(map[string]txtEntry),
maxCacheTTL: time.Duration(math.MaxUint32) * time.Second,
}

for _, o := range opts {
if err := o(r); err != nil {
return nil, err
}
}

return r, nil
}

var _ madns.BasicResolver = (*Resolver)(nil)
Expand Down Expand Up @@ -81,7 +110,8 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []ne
}
}

r.cacheIPAddr(domain, result, ttl)
cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL)
r.cacheIPAddr(domain, result, cacheTTL)
return result, nil
}

Expand All @@ -96,7 +126,8 @@ func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, erro
return nil, err
}

r.cacheTXT(domain, result, ttl)
cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL)
r.cacheTXT(domain, result, cacheTTL)
return result, nil
}

Expand All @@ -118,7 +149,7 @@ func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) {
return entry.ips, true
}

func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl time.Duration) {
if ttl == 0 {
return
}
Expand All @@ -127,7 +158,7 @@ func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
defer r.mx.Unlock()

fqdn := dns.Fqdn(domain)
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(time.Duration(ttl) * time.Second)}
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(ttl)}
}

func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
Expand All @@ -148,7 +179,7 @@ func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
return entry.txt, true
}

func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
func (r *Resolver) cacheTXT(domain string, txt []string, ttl time.Duration) {
if ttl == 0 {
return
}
Expand All @@ -157,5 +188,12 @@ func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
defer r.mx.Unlock()

fqdn := dns.Fqdn(domain)
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(time.Duration(ttl) * time.Second)}
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(ttl)}
}

func minTTL(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}
52 changes: 50 additions & 2 deletions resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/miekg/dns"
)
Expand Down Expand Up @@ -76,7 +77,10 @@ func TestLookupIPAddr(t *testing.T) {
})
defer resolver.Close()

r := NewResolver("")
r, err := NewResolver("https://cloudflare-dns.com/dns-query")
if err != nil {
t.Fatal("resolver cannot be initialised")
}
r.url = resolver.URL

ips, err := r.LookupIPAddr(context.Background(), domain)
Expand Down Expand Up @@ -120,7 +124,42 @@ func TestLookupTXT(t *testing.T) {
})
defer resolver.Close()

r := NewResolver("")
r, err := NewResolver("")
if err != nil {
t.Fatal("resolver cannot be initialised")
}
r.url = resolver.URL

txt, err := r.LookupTXT(context.Background(), domain)
if err != nil {
t.Fatal(err)
}
if len(txt) == 0 {
t.Fatal("got no TXT entries")
}

// check the cache
txt2, ok := r.getCachedTXT(domain)
if !ok {
t.Fatal("expected cache to be populated")
}
if !sameTXT(txt, txt2) {
t.Fatal("expected cache to contain the same txt entries")
}
}

func TestLookupCache(t *testing.T) {
domain := "example.com"
resolver := mockDoHResolver(t, map[uint16]*dns.Msg{
dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}),
})
defer resolver.Close()

const cacheTTL = time.Second
r, err := NewResolver("", WithMaxCacheTTL(cacheTTL))
if err != nil {
t.Fatal("resolver cannot be initialised")
}
r.url = resolver.URL

txt, err := r.LookupTXT(context.Background(), domain)
Expand All @@ -140,6 +179,15 @@ func TestLookupTXT(t *testing.T) {
t.Fatal("expected cache to contain the same txt entries")
}

// check cache is empty after its maxTTL
time.Sleep(cacheTTL)
txt2, ok = r.getCachedTXT(domain)
if ok {
t.Fatal("expected cache to be empty")
}
if txt2 != nil {
t.Fatal("expected cache to not contain a txt entry")
}
}

func sameIPs(a, b []net.IPAddr) bool {
Expand Down