Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 85 additions & 8 deletions client/internal/dns/mgmt/mgmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,26 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)

const dnsTimeout = 5 * time.Second
// Timing parameters for the management-domain cache.
const (
	// dnsTimeout is a five second DNS timeout; its consumers are outside
	// this section of the file.
	dnsTimeout time.Duration = 5 * time.Second

	// defaultTTL is both the TTL advertised on cached A/AAAA records and the
	// age after which ServeDNS treats a cache entry as stale.
	defaultTTL time.Duration = 5 * time.Minute

	// refreshBackoff is how long refreshDomain waits after a failed refresh
	// attempt before it tries to refresh the same entry again.
	refreshBackoff time.Duration = 30 * time.Second
)

// cachedRecord holds DNS records with their cache timestamp.
// One cachedRecord is stored per dns.Question in Resolver.records.
type cachedRecord struct {
// records are the resource records served for the question.
records []dns.RR
// cachedAt is the time the entry was stored; an entry older than
// defaultTTL is considered stale and triggers a refresh.
cachedAt time.Time
// lastFailedRefresh is set by refreshDomain when a refresh attempt does not
// produce a new entry; it is used to enforce refreshBackoff.
lastFailedRefresh *time.Time // timestamp of last failed refresh attempt, nil if never failed
}

// Resolver caches critical NetBird infrastructure domains
// and answers DNS questions for them from the in-memory records map.
// NOTE(review): two `records` declarations appear below; this looks like the
// old and new side of a diff rendered together — confirm which one is live.
type Resolver struct {
records map[dns.Question][]dns.RR
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
// mutex guards access to the records map.
mutex sync.RWMutex
// refreshMutex is a single resolver-wide lock taken for the whole of
// refreshDomain, so refreshes of all domains are serialized.
refreshMutex sync.Mutex // prevents concurrent refresh of the same domain
}

type ipsResponse struct {
Expand All @@ -35,7 +47,7 @@ type ipsResponse struct {
// NewResolver creates a new management domains cache resolver.
// The returned Resolver starts with an empty, non-nil records map; all other
// fields are left at their zero values.
// NOTE(review): the two `records:` keys below look like the old and new side
// of a diff rendered together — confirm which one is live.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
records: make(map[dns.Question]*cachedRecord),
}
}

Expand All @@ -60,14 +72,22 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}

m.mutex.RLock()
records, found := m.records[question]
cached, found := m.records[question]
m.mutex.RUnlock()

if !found {
m.continueToNext(w, r)
return
}

// Check if cache entry is stale (TTL expired)
var records []dns.RR
if time.Since(cached.cachedAt) > defaultTTL {
records = m.refreshDomain(question)
} else {
records = cached.records
}

resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
Expand Down Expand Up @@ -118,7 +138,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
Ttl: uint32(defaultTTL.Seconds()),
},
A: ip.AsSlice(),
}
Expand All @@ -129,14 +149,15 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
Ttl: uint32(defaultTTL.Seconds()),
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}

now := time.Now()
m.mutex.Lock()

if len(aRecords) > 0 {
Expand All @@ -145,7 +166,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
m.records[aQuestion] = &cachedRecord{
records: aRecords,
cachedAt: now,
}
}

if len(aaaaRecords) > 0 {
Expand All @@ -154,7 +178,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
m.records[aaaaQuestion] = &cachedRecord{
records: aaaaRecords,
cachedAt: now,
}
}

m.mutex.Unlock()
Expand All @@ -165,6 +192,56 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return nil
}

// refreshDomain re-resolves a stale cache entry for question and returns the
// records that should be served for it.
//
// Refreshes are serialized by refreshMutex, a single resolver-wide lock, so the
// blocking lookup performed through AddDomain holds up every other refresh
// while it runs. After the lock is taken the entry is re-read:
//   - if it no longer exists, nil is returned;
//   - if it is within defaultTTL again (another goroutine refreshed it), or a
//     failed attempt happened less than refreshBackoff ago, the cached records
//     are returned without a lookup.
//
// A refresh counts as failed when the question name cannot be parsed, when
// AddDomain returns an error, or when AddDomain succeeds but does not replace
// the entry for this question. In all of those cases the stale records are
// returned, so connectivity is not broken, and lastFailedRefresh is recorded so
// that the next attempt waits for refreshBackoff.
func (m *Resolver) refreshDomain(question dns.Question) []dns.RR {
	m.refreshMutex.Lock()
	defer m.refreshMutex.Unlock()

	// Re-read from the map after acquiring the lock: another goroutine may
	// have refreshed or removed the entry while we were waiting.
	m.mutex.RLock()
	current, found := m.records[question]
	m.mutex.RUnlock()

	if !found {
		return nil
	}

	// Already refreshed by another goroutine.
	if time.Since(current.cachedAt) <= defaultTTL {
		return current.records
	}

	// Still inside the backoff window of a previous failed attempt.
	if current.lastFailedRefresh != nil && time.Since(*current.lastFailedRefresh) < refreshBackoff {
		return current.records
	}

	// serveStale records the failed attempt and falls back to the stale
	// records. lastFailedRefresh is only accessed in this function, which
	// runs entirely under refreshMutex.
	serveStale := func() []dns.RR {
		now := time.Now()
		current.lastFailedRefresh = &now
		return current.records
	}

	d, err := domain.FromString(question.Name)
	if err != nil {
		log.Warnf("failed to parse cached domain name=%s: %v, serving stale cache", question.Name, err)
		return serveStale()
	}

	if err := m.AddDomain(context.Background(), d); err != nil {
		log.Warnf("failed to refresh domain=%s: %v, serving stale cache", d.SafeString(), err)
		return serveStale()
	}

	m.mutex.RLock()
	newCached, found := m.records[question]
	m.mutex.RUnlock()

	// AddDomain stores an entry only for record types that returned
	// addresses. If the entry is gone, or is still the very same pointer we
	// started with, this question was not refreshed: keep serving the stale
	// records and back off instead of repeating the lookup on every query.
	if !found || newCached == current {
		return serveStale()
	}

	log.Infof("refreshed cached domain=%s", d.SafeString())
	return newCached.records
}

func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
Expand Down
Loading