diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b3908f16313..cf41ea6d7fb 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -23,8 +23,8 @@ import ( ) const ( - netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" - netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS" + netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s-%s/DNS" + netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%s-%d/DNS" globalIPv4State = "State:/Network/Global/IPv4" primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS" keySupplementalMatchDomains = "SupplementalMatchDomains" @@ -51,14 +51,19 @@ const ( type systemConfigurator struct { createdKeys map[string]struct{} systemDNSSettings SystemDNSSettings + interfaceName string mu sync.RWMutex origNameservers []netip.Addr } -func newHostManager() (*systemConfigurator, error) { +func newHostManager(interfaceName string) (*systemConfigurator, error) { + if interfaceName == "" { + return nil, fmt.Errorf("interfaceName must not be empty") + } return &systemConfigurator{ - createdKeys: make(map[string]struct{}), + createdKeys: make(map[string]struct{}), + interfaceName: interfaceName, }, nil } @@ -67,6 +72,12 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { + if err := stateManager.UpdateState(&ShutdownState{ + InterfaceName: s.interfaceName, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } + var ( searchDomains []string matchDomains []string @@ -123,7 +134,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { - if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + if err := stateManager.UpdateState(&ShutdownState{ + InterfaceName: s.interfaceName, + CreatedKeys: maps.Keys(s.createdKeys), + }); err != nil { log.Errorf("failed to update shutdown state: %s", err) } } @@ -167,6 +181,7 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string { // discoverExistingKeys probes scutil for all NetBird DNS keys that may exist. // This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown). +// It also discovers legacy-format keys from older versions for upgrade migration. func (s *systemConfigurator) discoverExistingKeys() []string { dnsKeys, err := getSystemDNSKeys() if err != nil { @@ -176,16 +191,18 @@ func (s *systemConfigurator) discoverExistingKeys() []string { var keys []string + // Current format: interface-scoped named keys for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} { - key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix) + key := getKeyWithInput(netbirdDNSStateKeyFormat, s.interfaceName, suffix) if strings.Contains(dnsKeys, key) { keys = append(keys, key) } } + // Current format: interface-scoped indexed keys for _, suffix := range []string{searchSuffix, matchSuffix} { for i := 0; ; i++ { - key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i) + key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, s.interfaceName, suffix, i) if !strings.Contains(dnsKeys, key) { break } @@ -193,6 +210,25 @@ func (s *systemConfigurator) discoverExistingKeys() []string { } } + // Legacy format: non-interface-scoped keys from older versions + for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} { + legacyKey := fmt.Sprintf("State:/Network/Service/NetBird-%s/DNS", suffix) + if strings.Contains(dnsKeys, legacyKey) { + log.Infof("discovered legacy DNS key (no interface scope): %s", legacyKey) + keys = append(keys, legacyKey) + } + } + for _, suffix := range []string{searchSuffix, matchSuffix} { + for i := 0; ; i++ { + legacyKey := fmt.Sprintf("State:/Network/Service/NetBird-%s-%d/DNS", suffix, i) + if !strings.Contains(dnsKeys, legacyKey) { + break + } + log.Infof("discovered legacy indexed DNS key (no interface scope): %s", legacyKey) + keys = append(keys, legacyKey) + } + } + return keys } @@ -224,7 +260,7 @@ func (s *systemConfigurator) addLocalDNS() error { return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } - localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, s.interfaceName, localSuffix) if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") return nil @@ -258,7 +294,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { if err != nil || primaryServiceKey == "" { return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err) } - dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey) + dnsServiceKey := fmt.Sprintf(primaryServiceStateKeyFormat, primaryServiceKey) line := buildCommandLine("show", dnsServiceKey, "") stdinCommands := wrapCommand(line) @@ -385,7 +421,7 @@ func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, batches := splitDomainsIntoBatches(domains) for i, batch := range batches { - key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i) + key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, s.interfaceName, suffix, i) domainsStr := strings.Join(batch, " ") if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil { @@ -469,8 +505,8 @@ func (s *systemConfigurator) restoreUncleanShutdownDNS() error { return nil } -func getKeyWithInput(format, key string) string { - return fmt.Sprintf(format, key) +func getKeyWithInput(format, iface, key string) string { + return fmt.Sprintf(format, iface, key) } func buildAddCommandLine(key, value string) string { diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go index 94d020c39eb..7aaea40363e 100644 --- a/client/internal/dns/host_darwin_test.go +++ b/client/internal/dns/host_darwin_test.go @@ -6,6 +6,7 @@ import ( "bufio" "bytes" "context" + "encoding/json" "fmt" "net/netip" "os/exec" @@ -19,6 +20,8 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) +const testInterfaceName = "utun999" + func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { if testing.Short() { t.Skip("skipping scutil integration test in short mode") @@ -35,7 +38,8 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { }() configurator := &systemConfigurator{ - createdKeys: make(map[string]struct{}), + createdKeys: make(map[string]struct{}), + interfaceName: testInterfaceName, } config := HostDNSConfig{ @@ -52,7 +56,7 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { require.NoError(t, sm.PersistState(context.Background())) - localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, testInterfaceName, localSuffix) // Collect all created keys for cleanup verification createdKeys := make([]string, 0, len(configurator.createdKeys)) @@ -274,7 +278,8 @@ func TestMatchDomainBatching(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { configurator := &systemConfigurator{ - createdKeys: make(map[string]struct{}), + createdKeys: make(map[string]struct{}), + interfaceName: testInterfaceName, } defer func() { @@ -293,7 +298,7 @@ func TestMatchDomainBatching(t *testing.T) { // Read back all domains from all batched keys var got []string for i := range batches { - key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i) + key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, testInterfaceName, matchSuffix, i) exists, err := checkDNSKeyExists(key) require.NoError(t, err) require.True(t, exists, "key %s should exist", key) @@ -330,7 +335,8 @@ func removeTestDNSKey(key string) error { func TestGetOriginalNameservers(t *testing.T) { configurator := &systemConfigurator{ - createdKeys: make(map[string]struct{}), + createdKeys: make(map[string]struct{}), + interfaceName: testInterfaceName, origNameservers: []netip.Addr{ netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1"), @@ -345,7 +351,8 @@ func TestGetOriginalNameservers(t *testing.T) { func TestGetOriginalNameserversFromSystem(t *testing.T) { configurator := &systemConfigurator{ - createdKeys: make(map[string]struct{}), + createdKeys: make(map[string]struct{}), + interfaceName: testInterfaceName, } _, err := configurator.getSystemDNSSettings() @@ -373,7 +380,8 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man sm.Start() configurator := &systemConfigurator{ - createdKeys: make(map[string]struct{}), + createdKeys: make(map[string]struct{}), + interfaceName: testInterfaceName, } cleanup := func() { @@ -381,10 +389,8 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man for key := range configurator.createdKeys { _ = removeTestDNSKey(key) } - // Also clean up old-format keys and local key in case they exist - _ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)) - _ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)) - _ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)) + // Also clean up local key in case it exists + _ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, testInterfaceName, localSuffix)) } return configurator, sm, cleanup @@ -493,3 +499,270 @@ func TestOriginalNameserversRouteAllTransition(t *testing.T) { }) } } + +func TestGetKeyWithInput(t *testing.T) { + tests := []struct { + name string + format string + iface string + key string + expected string + }{ + { + name: "search key", + format: netbirdDNSStateKeyFormat, + iface: "utun0", + key: searchSuffix, + expected: "State:/Network/Service/NetBird-utun0-Search/DNS", + }, + { + name: "match key", + format: netbirdDNSStateKeyFormat, + iface: "utun0", + key: matchSuffix, + expected: "State:/Network/Service/NetBird-utun0-Match/DNS", + }, + { + name: "local key", + format: netbirdDNSStateKeyFormat, + iface: "utun0", + key: localSuffix, + expected: "State:/Network/Service/NetBird-utun0-Local/DNS", + }, + { + name: "different interface", + format: netbirdDNSStateKeyFormat, + iface: "utun100", + key: searchSuffix, + expected: "State:/Network/Service/NetBird-utun100-Search/DNS", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := getKeyWithInput(tc.format, tc.iface, tc.key) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestNewHostManagerWithInterfaceName(t *testing.T) { + manager, err := newHostManager("utun42") + require.NoError(t, err) + assert.Equal(t, "utun42", manager.interfaceName) + assert.NotNil(t, manager.createdKeys) +} + +func TestNewHostManagerWithEmptyInterfaceName(t *testing.T) { + manager, err := newHostManager("") + require.Error(t, err) + assert.Nil(t, manager) + assert.Contains(t, err.Error(), "interfaceName must not be empty") +} + +func TestMultipleInterfacesGenerateDifferentKeys(t *testing.T) { + iface1 := "utun0" + iface2 := "utun1" + + for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} { + key1 := getKeyWithInput(netbirdDNSStateKeyFormat, iface1, suffix) + key2 := getKeyWithInput(netbirdDNSStateKeyFormat, iface2, suffix) + assert.NotEqual(t, key1, key2, "keys for different interfaces should differ (suffix=%s)", suffix) + assert.Contains(t, key1, iface1) + assert.Contains(t, key2, iface2) + } + + // Also check indexed format + key1 := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, iface1, matchSuffix, 0) + key2 := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, iface2, matchSuffix, 0) + assert.NotEqual(t, key1, key2) + assert.Contains(t, key1, iface1) + assert.Contains(t, key2, iface2) +} + +func TestShutdownStateIncludesInterfaceName(t *testing.T) { + state := &ShutdownState{ + InterfaceName: "utun42", + CreatedKeys: []string{"key1", "key2"}, + } + assert.Equal(t, "utun42", state.InterfaceName) + assert.Equal(t, "dns_state", state.Name()) +} + +func TestPrimaryServiceKeyFormatNotAffected(t *testing.T) { + // primaryServiceStateKeyFormat has only one %s placeholder for the service UUID. + // It must NOT be called with getKeyWithInput (which expects iface + key). + serviceUUID := "12345678-ABCD-1234-ABCD-123456789ABC" + result := fmt.Sprintf(primaryServiceStateKeyFormat, serviceUUID) + assert.Equal(t, "State:/Network/Service/12345678-ABCD-1234-ABCD-123456789ABC/DNS", result) +} + +// TestMultipleInstancesBatchedIsolation verifies that two instances with +// different interfaces each get their own batched keys and don't interfere. +func TestMultipleInstancesBatchedIsolation(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + iface1 := "utun991" + iface2 := "utun992" + + cfg1 := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + interfaceName: iface1, + } + cfg2 := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + interfaceName: iface2, + } + + defer func() { + for key := range cfg1.createdKeys { + _ = removeTestDNSKey(key) + } + for key := range cfg2.createdKeys { + _ = removeTestDNSKey(key) + } + }() + + domains1 := generateShortDomains(60) // forces 2 batches + domains2 := generateShortDomains(60) + + require.NoError(t, cfg1.addBatchedDomains(matchSuffix, domains1, netip.MustParseAddr("100.64.0.1"), 53, false)) + require.NoError(t, cfg2.addBatchedDomains(matchSuffix, domains2, netip.MustParseAddr("100.64.0.2"), 53, false)) + + // Verify cfg1 keys contain iface1, not iface2 + for key := range cfg1.createdKeys { + assert.Contains(t, key, iface1) + assert.NotContains(t, key, iface2) + } + + // Verify cfg2 keys contain iface2, not iface1 + for key := range cfg2.createdKeys { + assert.Contains(t, key, iface2) + assert.NotContains(t, key, iface1) + } + + // Verify no key overlap + for key := range cfg1.createdKeys { + _, exists := cfg2.createdKeys[key] + assert.False(t, exists, "key %s should not exist in both instances", key) + } + + // Verify all domains readable from each instance's keys + var got1, got2 []string + for key := range cfg1.createdKeys { + got1 = append(got1, readDomainsFromKey(t, key)...) + } + for key := range cfg2.createdKeys { + got2 = append(got2, readDomainsFromKey(t, key)...) + } + assert.Equal(t, len(domains1), len(got1), "all domains from instance 1 should be readable") + assert.Equal(t, len(domains2), len(got2), "all domains from instance 2 should be readable") +} + +func TestShutdownStateUnmarshalLegacyJSON(t *testing.T) { + tests := []struct { + name string + json string + expectedIface string + expectedKeys []string + expectedKeysCount int + }{ + { + name: "legacy format (PascalCase, no interface)", + json: `{"CreatedKeys":["State:/Network/Service/NetBird-Match/DNS","State:/Network/Service/NetBird-Search/DNS"]}`, + expectedIface: "", + expectedKeys: []string{"State:/Network/Service/NetBird-Match/DNS", "State:/Network/Service/NetBird-Search/DNS"}, + expectedKeysCount: 2, + }, + { + name: "new format (snake_case, with interface)", + json: `{"interface_name":"utun0","created_keys":["State:/Network/Service/NetBird-utun0-Match-0/DNS"]}`, + expectedIface: "utun0", + expectedKeys: []string{"State:/Network/Service/NetBird-utun0-Match-0/DNS"}, + expectedKeysCount: 1, + }, + { + name: "empty legacy state", + json: `{}`, + expectedIface: "", + expectedKeysCount: 0, + }, + { + name: "legacy with empty keys", + json: `{"CreatedKeys":[]}`, + expectedIface: "", + expectedKeysCount: 0, + }, + { + name: "mixed fields: new format wins when populated", + json: `{"interface_name":"utun0","created_keys":["new-key"],"CreatedKeys":["old-key"]}`, + expectedIface: "utun0", + expectedKeys: []string{"new-key"}, + expectedKeysCount: 1, + }, + { + name: "mixed fields: legacy fills in when new is empty", + json: `{"created_keys":[],"CreatedKeys":["old-key"]}`, + expectedIface: "", + expectedKeys: []string{"old-key"}, + expectedKeysCount: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var state ShutdownState + err := json.Unmarshal([]byte(tc.json), &state) + require.NoError(t, err) + + assert.Equal(t, tc.expectedIface, state.InterfaceName) + assert.Len(t, state.CreatedKeys, tc.expectedKeysCount) + if tc.expectedKeys != nil { + assert.Equal(t, tc.expectedKeys, state.CreatedKeys) + } + }) + } +} + +func TestShutdownStateLegacyCleanupWithKeys(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + // Simulate an old-version unclean shutdown: write legacy-format scutil keys, + // then verify that Cleanup() with legacy state (no InterfaceName, PascalCase + // CreatedKeys) discovers and removes them. + legacyKey := fmt.Sprintf("State:/Network/Service/NetBird-%s/DNS", matchSuffix) + + // Write a legacy key to scutil + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + err := configurator.addDNSState(legacyKey, "legacy.example.com", netip.MustParseAddr("100.64.0.1"), 53, false) + require.NoError(t, err) + + defer func() { + _ = removeTestDNSKey(legacyKey) + }() + + exists, err := checkDNSKeyExists(legacyKey) + require.NoError(t, err) + require.True(t, exists, "legacy key should exist before cleanup") + + // Simulate deserializing old state: unmarshal legacy JSON (PascalCase, no InterfaceName) + // to exercise the full backward-compat path end-to-end. + legacyJSON := []byte(`{"CreatedKeys":["` + legacyKey + `"]}`) + var state ShutdownState + require.NoError(t, json.Unmarshal(legacyJSON, &state)) + require.Empty(t, state.InterfaceName, "legacy state should have no interface name") + require.Contains(t, state.CreatedKeys, legacyKey, "legacy key should be deserialized from PascalCase JSON") + + err = state.Cleanup() + require.NoError(t, err) + + exists, err = checkDNSKeyExists(legacyKey) + require.NoError(t, err) + assert.False(t, exists, "legacy key should be removed after cleanup") +} diff --git a/client/internal/dns/server_darwin.go b/client/internal/dns/server_darwin.go index d5a018f09b7..1163773f8ee 100644 --- a/client/internal/dns/server_darwin.go +++ b/client/internal/dns/server_darwin.go @@ -3,5 +3,5 @@ package dns func (s *DefaultServer) initialize() (manager hostManager, err error) { - return newHostManager() + return newHostManager(s.wgInterface.Name()) } diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index f51b5cf8d16..70164ed8098 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -3,11 +3,39 @@ package dns import ( + "encoding/json" "fmt" + + log "github.com/sirupsen/logrus" ) type ShutdownState struct { - CreatedKeys []string + InterfaceName string `json:"interface_name,omitempty"` + CreatedKeys []string `json:"created_keys,omitempty"` +} + +// UnmarshalJSON implements custom JSON unmarshaling to handle backward compatibility. +// Old versions serialized CreatedKeys without JSON tags (as "CreatedKeys" in JSON), +// while the new format uses "created_keys". This ensures both formats are handled. +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + type Alias ShutdownState + aux := &Alias{} + if err := json.Unmarshal(data, aux); err != nil { + return err + } + *s = ShutdownState(*aux) + + // If CreatedKeys is empty, try legacy format (no JSON tags, PascalCase keys) + if len(s.CreatedKeys) == 0 { + var legacy struct { + CreatedKeys []string `json:"CreatedKeys"` + } + if err := json.Unmarshal(data, &legacy); err == nil && len(legacy.CreatedKeys) > 0 { + s.CreatedKeys = legacy.CreatedKeys + } + } + + return nil } func (s *ShutdownState) Name() string { @@ -15,9 +43,21 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - manager, err := newHostManager() - if err != nil { - return fmt.Errorf("create host manager: %w", err) + var manager *systemConfigurator + if s.InterfaceName != "" { + var err error + manager, err = newHostManager(s.InterfaceName) + if err != nil { + return fmt.Errorf("create host manager: %w", err) + } + } else { + // State from an older version without interface name. + // Create a bare configurator so discoverExistingKeys() can find and + // remove legacy non-scoped scutil keys (e.g. NetBird-Match/DNS). + log.Warn("dns shutdown state has no interface name, falling back to legacy scutil key discovery") + manager = &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } } for _, key := range s.CreatedKeys {