From 7486738d0a4295d6c4bb1af8005b592feac737f9 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 21 Apr 2026 17:12:20 +0200 Subject: [PATCH 1/7] Add TTL-based refresh to mgmt DNS cache via handler chain --- client/internal/dns/handler_chain.go | 89 ++++ client/internal/dns/handler_chain_test.go | 164 +++++++ client/internal/dns/mgmt/mgmt.go | 399 ++++++++++++++---- client/internal/dns/mgmt/mgmt_refresh_test.go | 359 ++++++++++++++++ client/internal/dns/server.go | 1 + 5 files changed, 919 insertions(+), 93 deletions(-) create mode 100644 client/internal/dns/mgmt/mgmt_refresh_test.go diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 6fbdedc5953..1b1be7b54e0 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,7 +1,10 @@ package dns import ( + "context" "fmt" + "math" + "net" "slices" "strconv" "strings" @@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() { } func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + c.dispatch(w, r, math.MaxInt) +} + +// dispatch routes a DNS request through the chain, skipping handlers with +// priority > maxPriority. Shared by ServeDNS and ResolveInternal. +func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) { if len(r.Question) == 0 { return } @@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // Try handlers in priority order for _, entry := range handlers { + if entry.Priority > maxPriority { + continue + } if !c.isHandlerMatch(qname, entry) { continue } @@ -273,6 +285,50 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q cw.response.Len(), meta, time.Since(startTime)) } +// ResolveInternal runs an in-process DNS query against the chain, skipping any +// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt +// cache refresher) that must bypass themselves to avoid loops. Honors ctx +// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own +// (bounded by the invoked handler's internal timeout). +func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) { + if len(r.Question) == 0 { + return nil, fmt.Errorf("empty question") + } + + base := &internalResponseWriter{} + done := make(chan struct{}) + go func() { + c.dispatch(base, r, maxPriority) + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) + } + + if base.response == nil || base.response.Rcode == dns.RcodeRefused { + return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d", + strings.ToLower(r.Question[0].Name), maxPriority) + } + return base.response, nil +} + +// HasRootHandlerAtOrBelow reports whether any "." handler is registered at +// priority ≤ maxPriority. +func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, h := range c.handlers { + if h.Pattern == "." && h.Priority <= maxPriority { + return true + } + } + return false +} + func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { switch { case entry.Pattern == ".": @@ -291,3 +347,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { } } } + +// internalResponseWriter captures a dns.Msg for in-process chain queries. +type internalResponseWriter struct { + response *dns.Msg +} + +func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil } +func (w *internalResponseWriter) LocalAddr() net.Addr { return nil } +func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil } + +// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg +// still surface their answer to ResolveInternal. +func (w *internalResponseWriter) Write(p []byte) (int, error) { + msg := new(dns.Msg) + if err := msg.Unpack(p); err != nil { + return 0, err + } + w.response = msg + return len(p), nil +} + +func (w *internalResponseWriter) Close() error { return nil } +func (w *internalResponseWriter) TsigStatus() error { return nil } + +// TsigTimersOnly is part of dns.ResponseWriter. +func (w *internalResponseWriter) TsigTimersOnly(bool) { + // no-op: in-process queries carry no TSIG state. +} + +// Hijack is part of dns.ResponseWriter. +func (w *internalResponseWriter) Hijack() { + // no-op: in-process queries have no underlying connection to hand off. +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index fa9525069bf..034a760dc3f 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,11 +1,15 @@ package dns_test import ( + "context" + "net" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/test" @@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { }) } } + +// answeringHandler writes a fixed A record to ack the query. Used to verify +// which handler ResolveInternal dispatches to. +type answeringHandler struct { + name string + ip string +} + +func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + _ = w.WriteMsg(resp) +} + +func (h *answeringHandler) String() string { return h.name } + +func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) { + chain := nbdns.NewHandlerChain() + + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + low := &answeringHandler{name: "low", ip: "10.0.0.2"} + + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + chain.AddHandler("example.com.", low, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 1, len(resp.Answer)) + a, ok := resp.Answer[0].(*dns.A) + assert.True(t, ok) + assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream") +} + +func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) { + chain := nbdns.NewHandlerChain() + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + _, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.Error(t, err, "no handler at or below maxPriority should error") +} + +// rawWriteHandler packs a response and calls ResponseWriter.Write directly +// (instead of WriteMsg), exercising the internalResponseWriter.Write path. +type rawWriteHandler struct { + ip string +} + +func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + packed, err := resp.Pack() + if err != nil { + return + } + _, _ = w.Write(packed) +} + +func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) { + chain := nbdns.NewHandlerChain() + chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Answer, 1) + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer") +} + +func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) { + chain := nbdns.NewHandlerChain() + _, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream) + assert.Error(t, err) +} + +// hangingHandler blocks indefinitely until closed, simulating a wedged upstream. +type hangingHandler struct { + block chan struct{} +} + +func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + <-h.block + resp := &dns.Msg{} + resp.SetReply(r) + _ = w.WriteMsg(resp) +} + +func (h *hangingHandler) String() string { return "hangingHandler" } + +func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &hangingHandler{block: make(chan struct{})} + defer close(h.block) + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline") +} + +func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &answeringHandler{name: "h", ip: "10.0.0.1"} + + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain") + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count") + + chain.AddHandler(".", h, nbdns.PriorityMgmtCache) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded") + + chain.AddHandler(".", h, nbdns.PriorityDefault) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included") + + chain.RemoveHandler(".", nbdns.PriorityDefault) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) + + // Primary nsgroup case: root handler lands at PriorityUpstream. + chain.AddHandler(".", h, nbdns.PriorityUpstream) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included") + chain.RemoveHandler(".", nbdns.PriorityUpstream) + + // Fallback case: original /etc/resolv.conf entries land at PriorityFallback. + chain.AddHandler(".", h, nbdns.PriorityFallback) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included") + chain.RemoveHandler(".", nbdns.PriorityFallback) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) +} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 314af51d99d..7dacd49024c 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -2,40 +2,79 @@ package mgmt import ( "context" + "errors" "fmt" "net" - "net/netip" "net/url" + "os" "strings" "sync" + "sync/atomic" "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/shared/management/domain" ) -const dnsTimeout = 5 * time.Second +const ( + dnsTimeout = 5 * time.Second + defaultTTL = 300 * time.Second + refreshBackoff = 30 * time.Second -// Resolver caches critical NetBird infrastructure domains + // envMgmtCacheTTL overrides defaultTTL for integration/dev testing. + envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" +) + +// ChainResolver lets the cache refresh stale entries through the DNS handler +// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the +// system resolver. +type ChainResolver interface { + ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) + HasRootHandlerAtOrBelow(maxPriority int) bool +} + +// cachedRecord holds DNS records plus timestamps used for TTL refresh. +// records and cachedAt are set at construction and treated as immutable; +// lastFailedRefresh and consecFailures are mutable and must be accessed under +// Resolver.mutex. +type cachedRecord struct { + records []dns.RR + cachedAt time.Time + lastFailedRefresh time.Time + consecFailures int +} + +// Resolver caches critical NetBird infrastructure domains. +// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex. type Resolver struct { - records map[dns.Question][]dns.RR + records map[dns.Question]*cachedRecord mgmtDomain *domain.Domain serverDomains *dnsconfig.ServerDomains mutex sync.RWMutex -} -type ipsResponse struct { - ips []netip.Addr - err error + chain ChainResolver + chainMaxPriority int + refreshGroup singleflight.Group + + // refreshing tracks questions whose refresh is running via the OS + // fallback path. A ServeDNS hit for a question in this map indicates + // the OS resolver routed the recursive query back to us (loop). Only + // the OS path arms this so chain-path refreshes don't produce false + // positives. The atomic bool is CAS-flipped once per refresh to + // throttle the warning log. + refreshing map[dns.Question]*atomic.Bool } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - records: make(map[dns.Question][]dns.RR), + records: make(map[dns.Question]*cachedRecord), + refreshing: make(map[dns.Question]*atomic.Bool), } } @@ -44,7 +83,19 @@ func (m *Resolver) String() string { return "MgmtCacheResolver" } -// ServeDNS implements dns.Handler interface. +// SetChainResolver wires the handler chain used to refresh stale cache entries. +// maxPriority caps which handlers may answer refresh queries (typically +// PriorityUpstream, so upstream/default/fallback handlers are consulted and +// mgmt/route/local handlers are skipped). +func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { + m.mutex.Lock() + m.chain = chain + m.chainMaxPriority = maxPriority + m.mutex.Unlock() +} + +// ServeDNS serves cached A/AAAA records. Stale entries are returned +// immediately and refreshed asynchronously (stale-while-revalidate). func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -60,7 +111,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - records, found := m.records[question] + cached, found := m.records[question] + inflight := m.refreshing[question] + var shouldRefresh bool + if found { + stale := time.Since(cached.cachedAt) > cacheTTL() + inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff + shouldRefresh = stale && !inBackoff + } m.mutex.RUnlock() if !found { @@ -68,12 +126,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + if inflight != nil && inflight.CompareAndSwap(false, true) { + log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)", + question.Name) + } + + // Skip scheduling a refresh goroutine if one is already inflight for + // this question; singleflight would dedup anyway but skipping avoids + // a parked goroutine per stale hit under bursty load. + if shouldRefresh && inflight == nil { + m.scheduleRefresh(question) + } + resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - - resp.Answer = append(resp.Answer, records...) + resp.Answer = append(resp.Answer, cached.records...) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -98,101 +167,177 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { } } -// AddDomain manually adds a domain to cache by resolving it. +// AddDomain resolves a domain and stores its A/AAAA records in the cache. +// A family that resolves NODATA (nil err, zero records) evicts any stale +// entry for that qtype. func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := lookupIPWithExtraTimeout(ctx, d) - if err != nil { - return err + aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) + + if errA != nil && errAAAA != nil { + return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) } - var aRecords, aaaaRecords []dns.RR - for _, ip := range ips { - if ip.Is4() { - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: ip.AsSlice(), - } - aRecords = append(aRecords, rr) - } else if ip.Is6() { - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: ip.AsSlice(), - } - aaaaRecords = append(aaaaRecords, rr) - } + // Dual NODATA: don't wipe existing entries, let the caller retry. + if errA == nil && errAAAA == nil && len(aRecords) == 0 && len(aaaaRecords) == 0 { + return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } + now := time.Now() m.mutex.Lock() + defer m.mutex.Unlock() - if len(aRecords) > 0 { - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - m.records[aQuestion] = aRecords + m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) + m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) + + log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", + d.SafeString(), len(aRecords), len(aaaaRecords)) + + return nil +} + +// applyFamilyRecords writes records, evicts on NODATA, leaves the cache +// untouched on error. Caller holds m.mutex. +func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { + q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} + switch { + case len(records) > 0: + m.records[q] = &cachedRecord{records: records, cachedAt: now} + case err == nil: + delete(m.records, q) + } +} + +// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per +// unique in-flight key; bursty stale hits share its channel. +func (m *Resolver) scheduleRefresh(question dns.Question) { + key := question.Name + "|" + dns.TypeToString[question.Qtype] + _ = m.refreshGroup.DoChan(key, func() (any, error) { + return nil, m.refreshQuestion(question) + }) +} + +// refreshQuestion replaces the cached records on success, or marks the entry +// failed (arming the backoff) on failure. While this runs, ServeDNS can detect +// a resolver loop by spotting a query for this same question arriving on us. +func (m *Resolver) refreshQuestion(question dns.Question) error { + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) + if err != nil { + m.markRefreshFailed(question) + return fmt.Errorf("parse domain: %w", err) } - if len(aaaaRecords) > 0 { - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, + records, err := m.lookupRecords(ctx, d, question) + if err != nil { + fails := m.markRefreshFailed(question) + logf := log.Warnf + if fails > 1 { + logf = log.Debugf } - m.records[aaaaQuestion] = aaaaRecords + logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", + d.SafeString(), dns.TypeToString[question.Qtype], err, fails) + return err } - m.mutex.Unlock() + // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. + if len(records) == 0 { + m.mutex.Lock() + delete(m.records, question) + m.mutex.Unlock() + log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } - log.Debugf("added domain=%s with %d A records and %d AAAA records", - d.SafeString(), len(aRecords), len(aaaaRecords)) + now := time.Now() + m.mutex.Lock() + if _, stillCached := m.records[question]; !stillCached { + m.mutex.Unlock() + log.Debugf("skipping refresh write for domain=%s type=%s: entry was removed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.records[question] = &cachedRecord{records: records, cachedAt: now} + m.mutex.Unlock() + log.Infof("refreshed mgmt cache domain=%s type=%s", + d.SafeString(), dns.TypeToString[question.Qtype]) return nil } -func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { - log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) - defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) - resultChan := make(chan *ipsResponse, 1) +func (m *Resolver) markRefreshing(question dns.Question) { + m.mutex.Lock() + m.refreshing[question] = &atomic.Bool{} + m.mutex.Unlock() +} - go func() { - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) - resultChan <- &ipsResponse{ - err: err, - ips: ips, - } - }() +func (m *Resolver) clearRefreshing(question dns.Question) { + m.mutex.Lock() + delete(m.refreshing, question) + m.mutex.Unlock() +} - var resp *ipsResponse +// markRefreshFailed arms the backoff and returns the new consecutive-failure +// count so callers can downgrade subsequent failure logs to debug. +func (m *Resolver) markRefreshFailed(question dns.Question) int { + m.mutex.Lock() + defer m.mutex.Unlock() + c, ok := m.records[question] + if !ok { + return 0 + } + c.lastFailedRefresh = time.Now() + c.consecFailures++ + return c.consecFailures +} + +// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let +// callers tell records, NODATA (nil err, no records), and failure apart. +func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() - select { - case <-time.After(dnsTimeout + time.Millisecond*500): - log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) - return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) - case <-ctx.Done(): - return nil, ctx.Err() - case resp = <-resultChan: + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + aRecords, errA = lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) + aaaaRecords, errAAAA = lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) + return } - if resp.err != nil { - return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + // TODO: drop once every supported OS registers a fallback resolver. Safe + // today: no root handler at priority ≤ PriorityUpstream means NetBird is + // not the system resolver, so net.DefaultResolver will not loop back. + aRecords, errA = osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = osLookup(ctx, d, dnsName, dns.TypeAAAA) + return +} + +// lookupRecords resolves a single record type via chain or OS. The OS branch +// arms the loop detector for the duration of its call so that ServeDNS can +// spot the OS resolver routing the recursive query back to us. +func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + return lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) } - return resp.ips, nil + + // TODO: drop once every supported OS registers a fallback resolver. + m.markRefreshing(q) + defer m.clearRefreshing(q) + + return osLookup(ctx, d, q.Name, q.Qtype) } // PopulateFromConfig extracts and caches domains from the client configuration. @@ -224,19 +369,8 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - delete(m.records, aQuestion) - - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - delete(m.records, aaaaQuestion) + delete(m.records, dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}) + delete(m.records, dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}) log.Debugf("removed domain=%s from cache", d.SafeString()) return nil @@ -394,3 +528,82 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } + +func cacheTTL() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return defaultTTL +} + +// lookupViaChain resolves via the handler chain and rewrites each RR to use +// dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache +// target-owned records or upstream TTLs. NODATA returns (nil, nil). +func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { + msg := &dns.Msg{} + msg.SetQuestion(dnsName, qtype) + msg.RecursionDesired = true + + resp, err := chain.ResolveInternal(ctx, msg, maxPriority) + if err != nil { + return nil, fmt.Errorf("chain resolve: %w", err) + } + if resp == nil { + return nil, fmt.Errorf("chain resolve returned nil response") + } + if resp.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) + } + + ttl := uint32(cacheTTL().Seconds()) + var filtered []dns.RR + for _, rr := range resp.Answer { + if rr.Header().Class != dns.ClassINET { + continue + } + switch r := rr.(type) { + case *dns.A: + if qtype != dns.TypeA { + continue + } + filtered = append(filtered, &dns.A{ + Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, + A: append(net.IP(nil), r.A...), + }) + case *dns.AAAA: + if qtype != dns.TypeAAAA { + continue + } + filtered = append(filtered, &dns.AAAA{ + Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, + AAAA: append(net.IP(nil), r.AAAA...), + }) + } + } + return filtered, nil +} + +// osLookup resolves a single family via net.DefaultResolver using resutil, +// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA +// returns (nil, nil). +func osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { + network := resutil.NetworkForQtype(qtype) + if network == "" { + return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) + } + + log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + + result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + if result.Rcode == dns.RcodeSuccess { + return resutil.IPsToRRs(dnsName, result.IPs, uint32(cacheTTL().Seconds())), nil + } + + if result.Err != nil { + return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) + } + return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) +} diff --git a/client/internal/dns/mgmt/mgmt_refresh_test.go b/client/internal/dns/mgmt/mgmt_refresh_test.go new file mode 100644 index 00000000000..8cc24b59a3f --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_refresh_test.go @@ -0,0 +1,359 @@ +package mgmt + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +type fakeChain struct { + mu sync.Mutex + calls map[string]int + answers map[string][]dns.RR + err error + hasRoot bool + onLookup func() +} + +func newFakeChain() *fakeChain { + return &fakeChain{ + calls: map[string]int{}, + answers: map[string][]dns.RR{}, + hasRoot: true, + } +} + +func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.hasRoot +} + +func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) { + f.mu.Lock() + q := msg.Question[0] + key := q.Name + "|" + dns.TypeToString[q.Qtype] + f.calls[key]++ + answers := f.answers[key] + err := f.err + onLookup := f.onLookup + f.mu.Unlock() + + if onLookup != nil { + onLookup() + } + if err != nil { + return nil, err + } + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Answer = answers + return resp, nil +} + +func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) { + f.mu.Lock() + defer f.mu.Unlock() + key := name + "|" + dns.TypeToString[qtype] + hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60} + switch qtype { + case dns.TypeA: + f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}} + case dns.TypeAAAA: + f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}} + } +} + +func (f *fakeChain) callCount(name string, qtype uint16) int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls[name+"|"+dns.TypeToString[qtype]] +} + +// waitFor polls the predicate until it returns true or the deadline passes. +func waitFor(t *testing.T, d time.Duration, fn func() bool) { + t.Helper() + deadline := time.Now().Add(d) + for time.Now().Before(deadline) { + if fn() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("condition not met within %s", d) +} + +func queryA(t *testing.T, r *Resolver, name string) *dns.Msg { + t.Helper() + msg := new(dns.Msg) + msg.SetQuestion(name, dns.TypeA) + w := &test.MockResponseWriter{} + r.ServeDNS(w, msg) + return w.GetLastResponse() +} + +func firstA(t *testing.T, resp *dns.Msg) string { + t.Helper() + require.NotNil(t, resp) + require.Greater(t, len(resp.Answer), 0, "expected at least one answer") + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok, "expected A record") + return a.A.String() +} + +func TestResolver_ServeFresh_NoRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), // fresh + } + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(20 * time.Millisecond) + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh") +} + +func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), // stale + } + + // First query: serves stale immediately. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1 + }) + + // Next query should now return the refreshed IP. + waitFor(t, time.Second, func() bool { + resp := queryA(t, r, "mgmt.example.com.") + return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2" + }) +} + +func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + + var inflight atomic.Int32 + var maxInflight atomic.Int32 + chain.onLookup = func() { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + prev := maxInflight.Load() + if cur <= prev || maxInflight.CompareAndSwap(prev, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide + } + + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + queryA(t, r, "mgmt.example.com.") + }() + } + wg.Wait() + + waitFor(t, 2*time.Second, func() bool { + return inflight.Load() == 0 + }) + + calls := chain.callCount("mgmt.example.com.", dns.TypeA) + assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls) + assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently") +} + +func TestResolver_RefreshFailureArmsBackoff(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.err = errors.New("boom") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + // First stale hit triggers a refresh attempt that fails. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) == 1 + }) + waitFor(t, time.Second, func() bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + c, ok := r.records[q] + return ok && !c.lastFailedRefresh.IsZero() + }) + + // Subsequent stale hits within backoff window should not schedule more refreshes. + for i := 0; i < 10; i++ { + queryA(t, r, "mgmt.example.com.") + } + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes") +} + +func TestResolver_NoRootHandler_SkipsChain(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.hasRoot = false + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + // With hasRoot=false the chain must not be consulted. Use a short + // deadline so the OS fallback returns quickly without waiting on a + // real network call in CI. + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + _, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.") + + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), + "chain must not be used when no root handler is registered at the bound priority") +} + +func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) { + // ServeDNS being invoked for a question while a refresh for that question + // is inflight indicates a resolver loop (OS resolver sent the recursive + // query back to us). The inflightRefresh.loopLoggedOnce flag must be set. + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + // Simulate an inflight refresh. + r.markRefreshing(q) + defer r.clearRefreshing(q) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries") + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + require.NotNil(t, inflight) + assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed") +} + +func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + r.markRefreshing(q) + defer r.clearRefreshing(q) + + // Multiple ServeDNS calls during the same refresh must not re-set the flag + // (CompareAndSwap from false -> true returns true only on the first call). + for range 5 { + queryA(t, r, "mgmt.example.com.") + } + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + assert.True(t, inflight.Load()) +} + +func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + queryA(t, r, "mgmt.example.com.") + + r.mutex.RLock() + _, ok := r.refreshing[q] + r.mutex.RUnlock() + assert.False(t, ok, "no refresh inflight means no loop tracking") +} + +func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2") + r.SetChainResolver(chain, 50) + + require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com"))) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.2", firstA(t, resp)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA)) +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f7865047b50..d4f54dec581 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -212,6 +212,7 @@ func newDefaultServer( ctx, stop := context.WithCancel(ctx) mgmtCacheResolver := mgmt.NewResolver() + mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream) defaultServer := &DefaultServer{ ctx: ctx, From e6c62410ea38c66dcb0bb964793fbf79a85b9a86 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 21 Apr 2026 18:46:36 +0200 Subject: [PATCH 2/7] Update client/internal/dns/handler_chain.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- client/internal/dns/handler_chain.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 1b1be7b54e0..57e7722d440 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -305,7 +305,12 @@ func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPrior select { case <-done: case <-ctx.Done(): - return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) + // Prefer a completed response if dispatch finished concurrently with cancellation. + select { + case <-done: + default: + return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) + } } if base.response == nil || base.response.Rcode == dns.RcodeRefused { From 236e99af630377fba4c69f9ac2739e4ae149a110 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 22 Apr 2026 11:13:00 +0200 Subject: [PATCH 3/7] Parse mgmt cache TTL env var once via sync.OnceValue --- client/internal/dns/mgmt/mgmt.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 7dacd49024c..cef9603667e 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -529,14 +529,14 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } -func cacheTTL() time.Duration { +var cacheTTL = sync.OnceValue(func() time.Duration { if v := os.Getenv(envMgmtCacheTTL); v != "" { if d, err := time.ParseDuration(v); err == nil && d > 0 { return d } } return defaultTTL -} +}) // lookupViaChain resolves via the handler chain and rewrites each RR to use // dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache From 4c25ac674a719c944fd6b2557b479ac448efdfb2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 22 Apr 2026 11:24:16 +0200 Subject: [PATCH 4/7] Clamp served TTL to remaining cache life and guard refresh by identity --- client/internal/dns/mgmt/mgmt.go | 75 +++++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 16 deletions(-) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index cef9603667e..d6bb3d91a01 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -7,6 +7,7 @@ import ( "net" "net/url" "os" + "slices" "strings" "sync" "sync/atomic" @@ -135,14 +136,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // this question; singleflight would dedup anyway but skipping avoids // a parked goroutine per stale hit under bursty load. if shouldRefresh && inflight == nil { - m.scheduleRefresh(question) + m.scheduleRefresh(question, cached) } resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - resp.Answer = append(resp.Answer, cached.records...) + resp.Answer = cloneRecordsWithTTL(cached.records, responseTTL(cached.cachedAt)) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -213,30 +214,35 @@ func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dn } // scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per -// unique in-flight key; bursty stale hits share its channel. -func (m *Resolver) scheduleRefresh(question dns.Question) { +// unique in-flight key; bursty stale hits share its channel. expected is the +// cachedRecord pointer observed by the caller; the refresh only mutates the +// cache if that pointer is still the one stored, so a stale in-flight refresh +// can't clobber a newer entry written by AddDomain or a competing refresh. +func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) { key := question.Name + "|" + dns.TypeToString[question.Qtype] _ = m.refreshGroup.DoChan(key, func() (any, error) { - return nil, m.refreshQuestion(question) + return nil, m.refreshQuestion(question, expected) }) } // refreshQuestion replaces the cached records on success, or marks the entry // failed (arming the backoff) on failure. While this runs, ServeDNS can detect // a resolver loop by spotting a query for this same question arriving on us. -func (m *Resolver) refreshQuestion(question dns.Question) error { +// expected pins the cache entry observed at schedule time; mutations only apply +// if m.records[question] still points at it. +func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error { ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) if err != nil { - m.markRefreshFailed(question) + m.markRefreshFailed(question, expected) return fmt.Errorf("parse domain: %w", err) } records, err := m.lookupRecords(ctx, d, question) if err != nil { - fails := m.markRefreshFailed(question) + fails := m.markRefreshFailed(question, expected) logf := log.Warnf if fails > 1 { logf = log.Debugf @@ -249,18 +255,24 @@ func (m *Resolver) refreshQuestion(question dns.Question) error { // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. if len(records) == 0 { m.mutex.Lock() - delete(m.records, question) + if m.records[question] == expected { + delete(m.records, question) + m.mutex.Unlock() + log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } m.mutex.Unlock() - log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh", d.SafeString(), dns.TypeToString[question.Qtype]) return nil } now := time.Now() m.mutex.Lock() - if _, stillCached := m.records[question]; !stillCached { + if m.records[question] != expected { m.mutex.Unlock() - log.Debugf("skipping refresh write for domain=%s type=%s: entry was removed during refresh", + log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh", d.SafeString(), dns.TypeToString[question.Qtype]) return nil } @@ -286,11 +298,11 @@ func (m *Resolver) clearRefreshing(question dns.Question) { // markRefreshFailed arms the backoff and returns the new consecutive-failure // count so callers can downgrade subsequent failure logs to debug. -func (m *Resolver) markRefreshFailed(question dns.Question) int { +func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int { m.mutex.Lock() defer m.mutex.Unlock() c, ok := m.records[question] - if !ok { + if !ok || c != expected { return 0 } c.lastFailedRefresh = time.Now() @@ -529,6 +541,37 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } +// responseTTL returns the remaining cache lifetime in seconds (rounded up), +// so downstream resolvers don't cache an answer for longer than we will. +func responseTTL(cachedAt time.Time) uint32 { + remaining := cacheTTL() - time.Since(cachedAt) + if remaining <= 0 { + return 0 + } + return uint32((remaining + time.Second - 1) / time.Second) +} + +// cloneRecordsWithTTL deep-copies A/AAAA records and stamps ttl on each header +// so mutating the response never touches the cached slice. +func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { + out := make([]dns.RR, 0, len(records)) + for _, rr := range records { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + out = append(out, &cp) + case *dns.AAAA: + cp := *r + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + out = append(out, &cp) + } + } + return out +} + var cacheTTL = sync.OnceValue(func() time.Duration { if v := os.Getenv(envMgmtCacheTTL); v != "" { if d, err := time.ParseDuration(v); err == nil && d > 0 { @@ -570,7 +613,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d } filtered = append(filtered, &dns.A{ Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, - A: append(net.IP(nil), r.A...), + A: slices.Clone(r.A), }) case *dns.AAAA: if qtype != dns.TypeAAAA { @@ -578,7 +621,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d } filtered = append(filtered, &dns.AAAA{ Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, - AAAA: append(net.IP(nil), r.AAAA...), + AAAA: slices.Clone(r.AAAA), }) } } From cdc1ff8fd2a6983cf2513179bd39488591a157e2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 22 Apr 2026 11:40:54 +0200 Subject: [PATCH 5/7] Reject empty AddDomain results, filter chain answers by CNAME owner, dedup A/AAAA clone --- client/internal/dns/mgmt/mgmt.go | 108 ++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 39 deletions(-) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index d6bb3d91a01..03334334f1b 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -31,6 +31,15 @@ const ( envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" ) +var cacheTTL = sync.OnceValue(func() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return defaultTTL +}) + // ChainResolver lets the cache refresh stale entries through the DNS handler // chain instead of net.DefaultResolver, avoiding loopback when NetBird is the // system resolver. @@ -183,8 +192,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) } - // Dual NODATA: don't wipe existing entries, let the caller retry. - if errA == nil && errAAAA == nil && len(aRecords) == 0 && len(aaaaRecords) == 0 { + if len(aRecords) == 0 && len(aaaaRecords) == 0 { + if err := errors.Join(errA, errAAAA); err != nil { + return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) + } return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } @@ -551,36 +562,38 @@ func responseTTL(cachedAt time.Time) uint32 { return uint32((remaining + time.Second - 1) / time.Second) } -// cloneRecordsWithTTL deep-copies A/AAAA records and stamps ttl on each header -// so mutating the response never touches the cached slice. +// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non +// A/AAAA records return nil. +func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + return &cp + case *dns.AAAA: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + return &cp + } + return nil +} + +// cloneRecordsWithTTL clones A/AAAA records preserving their owner and +// stamping ttl so the response shares no memory with the cached slice. func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { out := make([]dns.RR, 0, len(records)) for _, rr := range records { - switch r := rr.(type) { - case *dns.A: - cp := *r - cp.Hdr.Ttl = ttl - cp.A = slices.Clone(r.A) - out = append(out, &cp) - case *dns.AAAA: - cp := *r - cp.Hdr.Ttl = ttl - cp.AAAA = slices.Clone(r.AAAA) - out = append(out, &cp) + if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil { + out = append(out, cp) } } return out } -var cacheTTL = sync.OnceValue(func() time.Duration { - if v := os.Getenv(envMgmtCacheTTL); v != "" { - if d, err := time.ParseDuration(v); err == nil && d > 0 { - return d - } - } - return defaultTTL -}) - // lookupViaChain resolves via the handler chain and rewrites each RR to use // dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache // target-owned records or upstream TTLs. NODATA returns (nil, nil). @@ -601,31 +614,48 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d } ttl := uint32(cacheTTL().Seconds()) + owners := cnameOwners(dnsName, resp.Answer) var filtered []dns.RR for _, rr := range resp.Answer { - if rr.Header().Class != dns.ClassINET { + h := rr.Header() + if h.Class != dns.ClassINET || h.Rrtype != qtype { + continue + } + if !owners[strings.ToLower(dns.Fqdn(h.Name))] { continue } - switch r := rr.(type) { - case *dns.A: - if qtype != dns.TypeA { + if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { + filtered = append(filtered, cp) + } + } + return filtered, nil +} + +// cnameOwners returns dnsName plus every target reachable by following CNAMEs +// in answer, iterating until fixed point so out-of-order chains resolve. +func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { + owners := map[string]bool{dnsName: true} + for { + added := false + for _, rr := range answer { + cname, ok := rr.(*dns.CNAME) + if !ok { continue } - filtered = append(filtered, &dns.A{ - Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, - A: slices.Clone(r.A), - }) - case *dns.AAAA: - if qtype != dns.TypeAAAA { + name := strings.ToLower(dns.Fqdn(cname.Hdr.Name)) + if !owners[name] { continue } - filtered = append(filtered, &dns.AAAA{ - Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, - AAAA: slices.Clone(r.AAAA), - }) + target := strings.ToLower(dns.Fqdn(cname.Target)) + if !owners[target] { + owners[target] = true + added = true + } + } + if !added { + return owners } } - return filtered, nil } // osLookup resolves a single family via net.DefaultResolver using resutil, From 0b21d7c24c80acf5a1cbc84662a8ad5d39ba8407 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 22 Apr 2026 13:04:37 +0200 Subject: [PATCH 6/7] Move cache TTL to Resolver field for injection in tests --- client/internal/dns/mgmt/mgmt.go | 173 +++++++++--------- client/internal/dns/mgmt/mgmt_refresh_test.go | 49 +++++ client/internal/dns/mgmt/mgmt_test.go | 55 ++++++ 3 files changed, 193 insertions(+), 84 deletions(-) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 03334334f1b..7fffb71a7a7 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -31,15 +31,6 @@ const ( envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" ) -var cacheTTL = sync.OnceValue(func() time.Duration { - if v := os.Getenv(envMgmtCacheTTL); v != "" { - if d, err := time.ParseDuration(v); err == nil && d > 0 { - return d - } - } - return defaultTTL -}) - // ChainResolver lets the cache refresh stale entries through the DNS handler // chain instead of net.DefaultResolver, avoiding loopback when NetBird is the // system resolver. @@ -78,6 +69,8 @@ type Resolver struct { // positives. The atomic bool is CAS-flipped once per refresh to // throttle the warning log. refreshing map[dns.Question]*atomic.Bool + + cacheTTL time.Duration } // NewResolver creates a new management domains cache resolver. @@ -85,6 +78,7 @@ func NewResolver() *Resolver { return &Resolver{ records: make(map[dns.Question]*cachedRecord), refreshing: make(map[dns.Question]*atomic.Bool), + cacheTTL: resolveCacheTTL(), } } @@ -125,7 +119,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { inflight := m.refreshing[question] var shouldRefresh bool if found { - stale := time.Since(cached.cachedAt) > cacheTTL() + stale := time.Since(cached.cachedAt) > m.cacheTTL inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff shouldRefresh = stale && !inBackoff } @@ -152,7 +146,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - resp.Answer = cloneRecordsWithTTL(cached.records, responseTTL(cached.cachedAt)) + resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt)) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -330,16 +324,16 @@ func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName stri m.mutex.RUnlock() if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { - aRecords, errA = lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) - aaaaRecords, errAAAA = lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) + aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) return } // TODO: drop once every supported OS registers a fallback resolver. Safe // today: no root handler at priority ≤ PriorityUpstream means NetBird is // not the system resolver, so net.DefaultResolver will not loop back. - aRecords, errA = osLookup(ctx, d, dnsName, dns.TypeA) - aaaaRecords, errAAAA = osLookup(ctx, d, dnsName, dns.TypeAAAA) + aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) return } @@ -353,14 +347,84 @@ func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Que m.mutex.RUnlock() if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { - return lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) + return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) } // TODO: drop once every supported OS registers a fallback resolver. m.markRefreshing(q) defer m.clearRefreshing(q) - return osLookup(ctx, d, q.Name, q.Qtype) + return m.osLookup(ctx, d, q.Name, q.Qtype) +} + +// lookupViaChain resolves via the handler chain and rewrites each RR to use +// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache +// target-owned records or upstream TTLs. NODATA returns (nil, nil). +func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { + msg := &dns.Msg{} + msg.SetQuestion(dnsName, qtype) + msg.RecursionDesired = true + + resp, err := chain.ResolveInternal(ctx, msg, maxPriority) + if err != nil { + return nil, fmt.Errorf("chain resolve: %w", err) + } + if resp == nil { + return nil, fmt.Errorf("chain resolve returned nil response") + } + if resp.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) + } + + ttl := uint32(m.cacheTTL.Seconds()) + owners := cnameOwners(dnsName, resp.Answer) + var filtered []dns.RR + for _, rr := range resp.Answer { + h := rr.Header() + if h.Class != dns.ClassINET || h.Rrtype != qtype { + continue + } + if !owners[strings.ToLower(dns.Fqdn(h.Name))] { + continue + } + if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { + filtered = append(filtered, cp) + } + } + return filtered, nil +} + +// osLookup resolves a single family via net.DefaultResolver using resutil, +// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA +// returns (nil, nil). +func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { + network := resutil.NetworkForQtype(qtype) + if network == "" { + return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) + } + + log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + + result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + if result.Rcode == dns.RcodeSuccess { + return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil + } + + if result.Err != nil { + return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) + } + return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) +} + +// responseTTL returns the remaining cache lifetime in seconds (rounded up), +// so downstream resolvers don't cache an answer for longer than we will. +func (m *Resolver) responseTTL(cachedAt time.Time) uint32 { + remaining := m.cacheTTL - time.Since(cachedAt) + if remaining <= 0 { + return 0 + } + return uint32((remaining + time.Second - 1) / time.Second) } // PopulateFromConfig extracts and caches domains from the client configuration. @@ -552,16 +616,6 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } -// responseTTL returns the remaining cache lifetime in seconds (rounded up), -// so downstream resolvers don't cache an answer for longer than we will. -func responseTTL(cachedAt time.Time) uint32 { - remaining := cacheTTL() - time.Since(cachedAt) - if remaining <= 0 { - return 0 - } - return uint32((remaining + time.Second - 1) / time.Second) -} - // cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non // A/AAAA records return nil. func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { @@ -594,43 +648,6 @@ func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { return out } -// lookupViaChain resolves via the handler chain and rewrites each RR to use -// dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache -// target-owned records or upstream TTLs. NODATA returns (nil, nil). -func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { - msg := &dns.Msg{} - msg.SetQuestion(dnsName, qtype) - msg.RecursionDesired = true - - resp, err := chain.ResolveInternal(ctx, msg, maxPriority) - if err != nil { - return nil, fmt.Errorf("chain resolve: %w", err) - } - if resp == nil { - return nil, fmt.Errorf("chain resolve returned nil response") - } - if resp.Rcode != dns.RcodeSuccess { - return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) - } - - ttl := uint32(cacheTTL().Seconds()) - owners := cnameOwners(dnsName, resp.Answer) - var filtered []dns.RR - for _, rr := range resp.Answer { - h := rr.Header() - if h.Class != dns.ClassINET || h.Rrtype != qtype { - continue - } - if !owners[strings.ToLower(dns.Fqdn(h.Name))] { - continue - } - if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { - filtered = append(filtered, cp) - } - } - return filtered, nil -} - // cnameOwners returns dnsName plus every target reachable by following CNAMEs // in answer, iterating until fixed point so out-of-order chains resolve. func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { @@ -658,25 +675,13 @@ func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { } } -// osLookup resolves a single family via net.DefaultResolver using resutil, -// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA -// returns (nil, nil). -func osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { - network := resutil.NetworkForQtype(qtype) - if network == "" { - return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) - } - - log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) - defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) - - result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) - if result.Rcode == dns.RcodeSuccess { - return resutil.IPsToRRs(dnsName, result.IPs, uint32(cacheTTL().Seconds())), nil - } - - if result.Err != nil { - return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) +// resolveCacheTTL reads the cache TTL override env var; invalid or empty +// values fall back to defaultTTL. Called once per Resolver from NewResolver. +func resolveCacheTTL() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } } - return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) + return defaultTTL } diff --git a/client/internal/dns/mgmt/mgmt_refresh_test.go b/client/internal/dns/mgmt/mgmt_refresh_test.go index 8cc24b59a3f..9faa5a0b88b 100644 --- a/client/internal/dns/mgmt/mgmt_refresh_test.go +++ b/client/internal/dns/mgmt/mgmt_refresh_test.go @@ -112,6 +112,55 @@ func firstA(t *testing.T, resp *dns.Msg) string { return a.A.String() } +func TestResolver_CacheTTLGatesRefresh(t *testing.T) { + // Same cached entry age, different cacheTTL values: the shorter TTL must + // trigger a background refresh, the longer one must not. Proves that the + // per-Resolver cacheTTL field actually drives the stale decision. + cachedAt := time.Now().Add(-100 * time.Millisecond) + + newRec := func() *cachedRecord { + return &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: cachedAt, + } + } + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + + t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = 10 * time.Millisecond + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount(q.Name, dns.TypeA) >= 1 + }) + }) + + t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = time.Hour + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh") + }) +} + func TestResolver_ServeFresh_NoRefresh(t *testing.T) { r := NewResolver() chain := newFakeChain() diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 9e8a746f3dc..276cbba0ab8 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -6,6 +6,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) { assert.False(t, resolver.MatchSubdomains()) } +func TestResolveCacheTTL(t *testing.T) { + tests := []struct { + name string + value string + want time.Duration + }{ + {"unset falls back to default", "", defaultTTL}, + {"valid duration", "45s", 45 * time.Second}, + {"valid minutes", "2m", 2 * time.Minute}, + {"malformed falls back to default", "not-a-duration", defaultTTL}, + {"zero falls back to default", "0s", defaultTTL}, + {"negative falls back to default", "-5s", defaultTTL}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envMgmtCacheTTL, tc.value) + got := resolveCacheTTL() + assert.Equal(t, tc.want, got, "parsed TTL should match") + }) + } +} + +func TestNewResolver_CacheTTLFromEnv(t *testing.T) { + t.Setenv(envMgmtCacheTTL, "7s") + r := NewResolver() + assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env") +} + +func TestResolver_ResponseTTL(t *testing.T) { + now := time.Now() + tests := []struct { + name string + cacheTTL time.Duration + cachedAt time.Time + wantMin uint32 + wantMax uint32 + }{ + {"fresh entry returns full TTL", 60 * time.Second, now, 59, 60}, + {"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31}, + {"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0}, + {"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := &Resolver{cacheTTL: tc.cacheTTL} + got := r.responseTTL(tc.cachedAt) + assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin") + assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax") + }) + } +} + func TestResolver_ExtractDomainFromURL(t *testing.T) { tests := []struct { name string From 19c1ba7cb780bfb6a4512970d0c4388f7633a496 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 22 Apr 2026 13:57:58 +0200 Subject: [PATCH 7/7] Demote fails=0 refresh log, clear refreshing markers on RemoveDomain --- client/internal/dns/mgmt/mgmt.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 7fffb71a7a7..988e427fb59 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -249,7 +249,7 @@ func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord if err != nil { fails := m.markRefreshFailed(question, expected) logf := log.Warnf - if fails > 1 { + if fails == 0 || fails > 1 { logf = log.Debugf } logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", @@ -456,8 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - delete(m.records, dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}) - delete(m.records, dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}) + qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} + qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} + delete(m.records, qA) + delete(m.records, qAAAA) + delete(m.refreshing, qA) + delete(m.refreshing, qAAAA) log.Debugf("removed domain=%s from cache", d.SafeString()) return nil