Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions client/internal/dns/mock_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}

// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// Mock implementation - no-op
}

// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
Expand Down
13 changes: 13 additions & 0 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type Server interface {
ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
SetRouteChecker(func(netip.Addr) bool)
}

type nsGroupsByDomain struct {
Expand Down Expand Up @@ -104,6 +105,7 @@ type DefaultServer struct {

statusRecorder *peer.Status
stateManager *statemanager.Manager
routeMatch func(netip.Addr) bool

probeMu sync.Mutex
probeCancel context.CancelFunc
Expand Down Expand Up @@ -229,6 +231,14 @@ func newDefaultServer(
return defaultServer
}

// SetRouteChecker sets the function used by upstream resolvers to determine
// whether an IP is routed through the tunnel.
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
s.mux.Lock()
defer s.mux.Unlock()
s.routeMatch = f
}

// RegisterHandler registers a handler for the given domains with the given priority.
// Any previously registered handler for the same domain and priority will be replaced.
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
Expand Down Expand Up @@ -743,6 +753,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
return
}
handler.routeMatch = s.routeMatch
Comment thread
mlsmaycon marked this conversation as resolved.

for _, ns := range originalNameservers {
if ns == config.ServerIP {
Expand Down Expand Up @@ -852,6 +863,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
if err != nil {
return nil, fmt.Errorf("create upstream resolver: %v", err)
}
handler.routeMatch = s.routeMatch
Comment thread
mlsmaycon marked this conversation as resolved.

for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
Expand Down Expand Up @@ -1036,6 +1048,7 @@ func (s *DefaultServer) addHostRootZone() {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
}
handler.routeMatch = s.routeMatch
Comment thread
mlsmaycon marked this conversation as resolved.

handler.upstreamServers = maps.Keys(hostDNSServers)
handler.deactivate = func(error) {}
Expand Down
1 change: 1 addition & 0 deletions client/internal/dns/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type upstreamResolverBase struct {
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}

type upstreamFailure struct {
Expand Down
8 changes: 5 additions & 3 deletions client/internal/dns/upstream_ios.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} else {
upstreamIP = upstreamIP.Unmap()
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream)
needsPrivate := u.lNet.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
return nil, 0, fmt.Errorf("create private client: %s", err)
}
}

Expand Down
11 changes: 11 additions & 0 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)

e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)

e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})

if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
Expand Down
Loading