Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 131 additions & 38 deletions client/internal/dns/host_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"strings"
"sync"

"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"

Expand All @@ -22,6 +24,7 @@ import (

const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
Expand All @@ -35,6 +38,14 @@ const (
searchSuffix = "Search"
matchSuffix = "Match"
localSuffix = "Local"

// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
maxDomainsPerResolverEntry = 50

// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
maxDomainBytesPerResolverEntry = 1500
)

type systemConfigurator struct {
Expand Down Expand Up @@ -84,28 +95,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
}

matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing match domains from the system")
err = s.removeKeyFromSystemConfig(matchKey)
if err := s.removeKeysContaining(matchSuffix); err != nil {
log.Warnf("failed to remove old match keys: %v", err)
}
if err != nil {
return fmt.Errorf("add match domains: %w", err)
if len(matchDomains) != 0 {
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
return fmt.Errorf("add match domains: %w", err)
}
}
s.updateState(stateManager)

searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing search domains from the system")
err = s.removeKeyFromSystemConfig(searchKey)
if err := s.removeKeysContaining(searchSuffix); err != nil {
log.Warnf("failed to remove old search keys: %v", err)
}
if err != nil {
return fmt.Errorf("add search domains: %w", err)
if len(searchDomains) != 0 {
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
}
s.updateState(stateManager)

Expand Down Expand Up @@ -149,8 +155,7 @@ func (s *systemConfigurator) restoreHostDNS() error {

func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
return s.discoverExistingKeys()
}

keys := make([]string, 0, len(s.createdKeys))
Expand All @@ -160,6 +165,47 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
return keys
}

// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
func (s *systemConfigurator) discoverExistingKeys() []string {
dnsKeys, err := getSystemDNSKeys()
if err != nil {
log.Errorf("failed to get system DNS keys: %v", err)
return nil
}

var keys []string

for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
if strings.Contains(dnsKeys, key) {
keys = append(keys, key)
}
}

for _, suffix := range []string{searchSuffix, matchSuffix} {
for i := 0; ; i++ {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
if !strings.Contains(dnsKeys, key) {
break
}
keys = append(keys, key)
}
}

return keys
}

// getSystemDNSKeys gets all DNS keys
func getSystemDNSKeys() (string, error) {
command := "list .*DNS\nquit\n"
out, err := runSystemConfigCommand(command)
if err != nil {
return "", err
}
return string(out), nil
}

func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line))
Expand All @@ -184,12 +230,11 @@ func (s *systemConfigurator) addLocalDNS() error {
return nil
}

if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
return fmt.Errorf("add local dns state: %w", err)
}
s.createdKeys[localKey] = struct{}{}

return nil
}
Expand Down Expand Up @@ -280,28 +325,77 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}

func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {
return fmt.Errorf("add dns state: %w", err)
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
func splitDomainsIntoBatches(domains []string) [][]string {
if len(domains) == 0 {
return nil
}

log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
var batches [][]string
var current []string
currentBytes := 0

s.createdKeys[key] = struct{}{}
for _, d := range domains {
domainLen := len(d)
newBytes := currentBytes + domainLen
if currentBytes > 0 {
newBytes++ // space separator
}

return nil
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
batches = append(batches, current)
current = nil
currentBytes = 0
}

current = append(current, d)
if currentBytes > 0 {
currentBytes += 1 + domainLen
} else {
currentBytes = domainLen
}
}

if len(current) > 0 {
batches = append(batches, current)
}

return batches
}

func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
err := s.addDNSState(key, domains, dnsServer, port, false)
if err != nil {
return fmt.Errorf("add dns state: %w", err)
// removeKeysContaining removes all created keys that contain the given substring.
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
Comment thread
mlsmaycon marked this conversation as resolved.
var toRemove []string
for key := range s.createdKeys {
if strings.Contains(key, suffix) {
toRemove = append(toRemove, key)
}
}
var multiErr *multierror.Error
for _, key := range toRemove {
if err := s.removeKeyFromSystemConfig(key); err != nil {
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
}
}
return nberrors.FormatErrorOrNil(multiErr)
}

log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
batches := splitDomainsIntoBatches(domains)

s.createdKeys[key] = struct{}{}
for i, batch := range batches {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
domainsStr := strings.Join(batch, " ")

if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
return fmt.Errorf("add dns state for batch %d: %w", i, err)
}

s.createdKeys[key] = struct{}{}
}

log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))

return nil
}
Expand Down Expand Up @@ -364,7 +458,6 @@ func (s *systemConfigurator) flushDNSCache() error {
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
}

log.Info("flushed DNS cache")
return nil
}
Expand Down
Loading
Loading