Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/internal/dns/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}

// ProbeAvailability is a no-op for the local resolver: the body is empty.
// NOTE(review): the two lines below appear to be the pre-change and
// post-change forms of the same method as rendered by the diff view; the
// context.Context variant is the one matching the ProbeAvailability(context.Context)
// entry in handlerWithStop — confirm against the real file.
func (d *Resolver) ProbeAvailability() {}
func (d *Resolver) ProbeAvailability(context.Context) {}

// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
Expand Down
60 changes: 54 additions & 6 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,16 @@ type DefaultServer struct {

statusRecorder *peer.Status
stateManager *statemanager.Manager

probeMu sync.Mutex
probeCancel context.CancelFunc
probeWg sync.WaitGroup
}

// handlerWithStop is the handler contract used by DefaultServer: a dns.Handler
// that can additionally be stopped, probed for availability, and identified.
// NOTE(review): both the parameterless and the context-taking ProbeAvailability
// appear here because the diff view shows the removed and added line together;
// presumably only the context.Context form exists after the change.
type handlerWithStop interface {
dns.Handler
// Stop is called to shut the handler down.
Stop()
ProbeAvailability()
// ProbeAvailability receives the per-probe context that
// DefaultServer.ProbeAvailability derives from the server context and
// cancels when a newer probe starts, when the probe completes, or when the
// server is stopped.
ProbeAvailability(context.Context)
// ID returns the handler's identifier.
ID() types.HandlerID
}

Expand Down Expand Up @@ -362,7 +366,13 @@ func (s *DefaultServer) DnsIP() netip.Addr {

// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
if s.probeCancel != nil {
s.probeCancel()
}
s.ctxCancel()
s.probeMu.Unlock()
s.probeWg.Wait()
s.shutdownWg.Wait()

s.mux.Lock()
Expand Down Expand Up @@ -479,7 +489,8 @@ func (s *DefaultServer) SearchDomains() []string {
}

// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
// and deactivates the group if no server responds.
// If a previous probe is still running, it will be cancelled before starting a new one.
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
Expand All @@ -492,15 +503,52 @@ func (s *DefaultServer) ProbeAvailability() {
}
}

var wg sync.WaitGroup
s.probeMu.Lock()

// don't start probes on a stopped server
if s.ctx.Err() != nil {
s.probeMu.Unlock()
return
}

// cancel any running probe
if s.probeCancel != nil {
s.probeCancel()
s.probeCancel = nil
}

// wait for the previous probe goroutines to finish while holding
// the mutex so no other caller can start a new probe concurrently
s.probeWg.Wait()
Comment thread
lixmal marked this conversation as resolved.

// start a new probe
probeCtx, probeCancel := context.WithCancel(s.ctx)
s.probeCancel = probeCancel

s.probeWg.Add(1)
defer s.probeWg.Done()

// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
s.mux.Lock()
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
for _, mux := range s.dnsMuxMap {
handlers = append(handlers, mux.handler)
}
s.mux.Unlock()

var wg sync.WaitGroup
for _, handler := range handlers {
wg.Add(1)
go func(mux handlerWithStop) {
go func(h handlerWithStop) {
defer wg.Done()
mux.ProbeAvailability()
}(mux.handler)
h.ProbeAvailability(probeCtx)
}(handler)
}

s.probeMu.Unlock()

wg.Wait()
probeCancel()
}

func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
Expand Down
2 changes: 1 addition & 1 deletion client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ type mockHandler struct {

// mockHandler test double: ServeDNS, Stop and ProbeAvailability have empty
// bodies; ID converts the mock's Id field to a types.HandlerID.
// NOTE(review): the two ProbeAvailability lines look like the removed and
// added sides of the diff shown together — confirm only the
// context.Context variant remains in the real file.
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }

type mockService struct{}
Expand Down
61 changes: 40 additions & 21 deletions client/internal/dns/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type upstreamResolverBase struct {
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup

deactivate func(error)
reactivate func()
Expand Down Expand Up @@ -115,6 +116,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()

u.mutex.Lock()
u.wg.Wait()
Comment thread
lixmal marked this conversation as resolved.
u.mutex.Unlock()

}
Comment thread
pappz marked this conversation as resolved.

// ServeDNS handles a DNS request
Expand Down Expand Up @@ -260,16 +266,10 @@ func formatFailures(failures []upstreamFailure) string {

// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
u.mutex.Lock()
defer u.mutex.Unlock()

select {
case <-u.ctx.Done():
return
default:
}

// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
Expand All @@ -279,31 +279,39 @@ func (u *upstreamResolverBase) ProbeAvailability() {
var mu sync.Mutex
var wg sync.WaitGroup

var errors *multierror.Error
var errs *multierror.Error
for _, upstream := range u.upstreamServers {
upstream := upstream

wg.Add(1)
go func() {
go func(upstream netip.AddrPort) {
defer wg.Done()
err := u.testNameserver(upstream, 500*time.Millisecond)
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
if err != nil {
errors = multierror.Append(errors, err)
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}

mu.Lock()
defer mu.Unlock()
success = true
}()
mu.Unlock()
}(upstream)
}

wg.Wait()

select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
Comment thread
pappz marked this conversation as resolved.

// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errors.ErrorOrNil())
u.disable(errs.ErrorOrNil())

if u.statusRecorder == nil {
return
Expand Down Expand Up @@ -339,7 +347,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}

for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream, probeTimeout); err != nil {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
Expand All @@ -364,7 +372,9 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
}

// isTimeout returns true if the given error is a network timeout error.
Expand All @@ -387,7 +397,11 @@ func (u *upstreamResolverBase) disable(err error) {
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
Comment thread
pappz marked this conversation as resolved.
}

func (u *upstreamResolverBase) upstreamServersString() string {
Expand All @@ -398,13 +412,18 @@ func (u *upstreamResolverBase) upstreamServersString() string {
return strings.Join(servers, ", ")
}

// testNameserver sends a single SOA query for testRecord to server through
// u.upstreamClient and returns the exchange error (nil on success).
//
// NOTE(review): this span shows removed and added diff lines together. The
// first signature, the `ctx, cancel := ...` line and the exchange call using
// `ctx` appear to be the pre-change version; the lines using baseCtx,
// externalCtx and mergedCtx appear to be the post-change version.
//
// In the post-change form the query context is baseCtx bounded by timeout;
// externalCtx is optional (callers may pass nil) and, when set, additionally
// cancels the query as soon as it is done.
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
defer cancel()

// Tie externalCtx's cancellation to mergedCtx: AfterFunc runs cancel once
// externalCtx is done; the returned stop function detaches that hook when
// this call returns first.
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}

r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)

_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
return err
}

Expand Down
2 changes: 1 addition & 1 deletion client/internal/dns/upstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true
}

resolver.ProbeAvailability()
resolver.ProbeAvailability(context.TODO())

if !failed {
t.Errorf("expected that resolving was deactivated")
Expand Down
3 changes: 1 addition & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -1315,8 +1315,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {

// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
e.dnsServer.ProbeAvailability()

go e.dnsServer.ProbeAvailability()
return nil
}

Expand Down
Loading