diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 314af51d99d..3185c151d8c 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -17,14 +17,26 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" ) -const dnsTimeout = 5 * time.Second +const ( + dnsTimeout = 5 * time.Second + defaultTTL = 300 * time.Second + refreshBackoff = 30 * time.Second // wait time after failed refresh attempt +) + +// cachedRecord holds DNS records with their cache timestamp. +type cachedRecord struct { + records []dns.RR + cachedAt time.Time + lastFailedRefresh *time.Time // timestamp of last failed refresh attempt, nil if never failed +} // Resolver caches critical NetBird infrastructure domains type Resolver struct { - records map[dns.Question][]dns.RR + records map[dns.Question]*cachedRecord mgmtDomain *domain.Domain serverDomains *dnsconfig.ServerDomains mutex sync.RWMutex + refreshMutex sync.Mutex // prevents concurrent refresh of the same domain } type ipsResponse struct { @@ -35,7 +47,7 @@ type ipsResponse struct { // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - records: make(map[dns.Question][]dns.RR), + records: make(map[dns.Question]*cachedRecord), } } @@ -60,7 +72,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - records, found := m.records[question] + cached, found := m.records[question] m.mutex.RUnlock() if !found { @@ -68,6 +80,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + // Check if cache entry is stale (TTL expired) + var records []dns.RR + if time.Since(cached.cachedAt) > defaultTTL { + records = m.refreshDomain(question) + } else { + records = cached.records + } + resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false @@ -118,7 +138,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, - Ttl: 300, + Ttl: uint32(defaultTTL.Seconds()), }, A: ip.AsSlice(), } @@ -129,7 +149,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, - Ttl: 300, + Ttl: uint32(defaultTTL.Seconds()), }, AAAA: ip.AsSlice(), } @@ -137,6 +157,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { } } + now := time.Now() m.mutex.Lock() if len(aRecords) > 0 { @@ -145,7 +166,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { Qtype: dns.TypeA, Qclass: dns.ClassINET, } - m.records[aQuestion] = aRecords + m.records[aQuestion] = &cachedRecord{ + records: aRecords, + cachedAt: now, + } } if len(aaaaRecords) > 0 { @@ -154,7 +178,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { Qtype: dns.TypeAAAA, Qclass: dns.ClassINET, } - m.records[aaaaQuestion] = aaaaRecords + m.records[aaaaQuestion] = &cachedRecord{ + records: aaaaRecords, + cachedAt: now, + } } m.mutex.Unlock() @@ -165,6 +192,56 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { return nil } +// refreshDomain refreshes a stale cached domain using DefaultResolver. +// On failure, it returns the stale records to avoid breaking connectivity. +// A backoff mechanism prevents repeated blocking refresh attempts after failures. +func (m *Resolver) refreshDomain(question dns.Question) []dns.RR { + m.refreshMutex.Lock() + defer m.refreshMutex.Unlock() + + // Re-read from map after acquiring lock to check if another goroutine refreshed + m.mutex.RLock() + current, found := m.records[question] + m.mutex.RUnlock() + + if !found { + return nil + } + + // Check if already refreshed by another goroutine + if time.Since(current.cachedAt) <= defaultTTL { + return current.records + } + + // Check if we're in backoff period after a failed refresh + if current.lastFailedRefresh != nil && time.Since(*current.lastFailedRefresh) < refreshBackoff { + return current.records + } + + d, _ := domain.FromString(question.Name) + + if err := m.AddDomain(context.Background(), d); err != nil { + log.Warnf("failed to refresh domain=%s: %v, serving stale cache", d.SafeString(), err) + now := time.Now() + current.lastFailedRefresh = &now + return current.records + } + + m.mutex.RLock() + newCached, found := m.records[question] + m.mutex.RUnlock() + + if !found { + // DNS returned no records for this type, preserve stale with backoff + now := time.Now() + current.lastFailedRefresh = &now + return current.records + } + + log.Infof("refreshed cached domain=%s", d.SafeString()) + return newCached.records +} + func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())