From c9243cda39b9ebf13ab6e8e7b7e82d84feb9e01c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 6 Feb 2026 18:55:31 +0800 Subject: [PATCH 1/8] Add TCP DNS support and clean up packet hook interface --- client/firewall/uspfilter/filter.go | 132 ++--- client/firewall/uspfilter/filter_test.go | 151 ++---- client/firewall/uspfilter/hooks_bench_test.go | 182 +++++++ client/firewall/uspfilter/rule.go | 4 +- client/firewall/uspfilter/tracer_test.go | 16 +- client/iface/device/device_filter.go | 19 +- client/iface/mocks/filter.go | 40 +- client/iface/mocks/iface/mocks/filter.go | 87 ---- client/internal/dns/handler_chain.go | 12 +- client/internal/dns/response_writer.go | 36 ++ client/internal/dns/server.go | 9 +- client/internal/dns/server_test.go | 6 +- client/internal/dns/service.go | 12 +- client/internal/dns/service_listener.go | 94 +++- client/internal/dns/service_listener_test.go | 89 ++++ client/internal/dns/service_memory.go | 73 ++- client/internal/dns/tcpstack.go | 459 ++++++++++++++++++ client/internal/dns/upstream.go | 142 +++++- client/internal/dns/upstream_android.go | 4 +- client/internal/dns/upstream_test.go | 177 +++++++ client/internal/dnsfwd/forwarder.go | 20 +- client/internal/engine.go | 1 + 22 files changed, 1370 insertions(+), 395 deletions(-) create mode 100644 client/firewall/uspfilter/hooks_bench_test.go delete mode 100644 client/iface/mocks/iface/mocks/filter.go create mode 100644 client/internal/dns/service_listener_test.go create mode 100644 client/internal/dns/tcpstack.go diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index df2e274ebce..6415af17fd0 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -140,6 +140,17 @@ type Manager struct { mtu uint16 mssClampValue uint16 mssClampEnabled bool + + // Only one hook per protocol is supported. Outbound direction only. 
+ udpHookOut atomic.Pointer[packetHook] + tcpHookOut atomic.Pointer[packetHook] +} + +// packetHook stores a registered hook for a specific IP:port. +type packetHook struct { + ip netip.Addr + port uint16 + fn func([]byte) bool } // decoder for packages @@ -713,6 +724,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } case layers.LayerTypeTCP: + if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) { + return true + } // Clamp MSS on all TCP SYN packets, including those from local IPs. // SNATed routed traffic may appear as local IP but still requires clamping. if m.mssClampEnabled { @@ -895,38 +909,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt d.dnatOrigPort = 0 } -// udpHooksDrop checks if any UDP hooks should drop the packet func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { - m.mutex.RLock() - defer m.mutex.RUnlock() + return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData) +} - // Check specific destination IP first - if rules, exists := m.outgoingRules[dstIP]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } - } +func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { + return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData) +} - // Check IPv4 unspecified address - if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return rule.udpHook(packetData) - } - } +func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool { + if h == nil { + return false } - - // Check IPv6 unspecified address - if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists { - for _, rule := range rules { - if rule.udpHook != nil && portsMatch(rule.dPort, dport) { - return 
rule.udpHook(packetData) - } - } + if h.ip == dstIP && h.port == dport { + return h.fn(packetData) } - return false } @@ -1278,12 +1275,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d return rule.mgmtId, rule.drop, true } case layers.LayerTypeUDP: - // if rule has UDP hook (and if we are here we match this rule) - // we ignore rule.drop and call this hook - if rule.udpHook != nil { - return rule.mgmtId, rule.udpHook(packetData), true - } - if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { return rule.mgmtId, rule.drop, true } @@ -1342,65 +1333,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot return sourceMatched } -// AddUDPPacketHook calls hook when UDP packet from given direction matched -// -// Hook function returns flag which indicates should be the matched package dropped or not -func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string { - r := PeerRule{ - id: uuid.New().String(), - ip: ip, - protoLayer: layers.LayerTypeUDP, - dPort: &firewall.Port{Values: []uint16{dPort}}, - ipLayer: layers.LayerTypeIPv6, - udpHook: hook, - } - - if ip.Is4() { - r.ipLayer = layers.LayerTypeIPv4 +// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove. 
+func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { + if hook == nil { + m.udpHookOut.Store(nil) + return } - - m.mutex.Lock() - if in { - // Incoming UDP hooks are stored in allow rules map - if _, ok := m.incomingRules[r.ip]; !ok { - m.incomingRules[r.ip] = make(map[string]PeerRule) - } - m.incomingRules[r.ip][r.id] = r - } else { - if _, ok := m.outgoingRules[r.ip]; !ok { - m.outgoingRules[r.ip] = make(map[string]PeerRule) - } - m.outgoingRules[r.ip][r.id] = r - } - m.mutex.Unlock() - - return r.id + m.udpHookOut.Store(&packetHook{ + ip: ip, + port: dPort, + fn: hook, + }) } -// RemovePacketHook removes packet hook by given ID -func (m *Manager) RemovePacketHook(hookID string) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - // Check incoming hooks (stored in allow rules) - for _, arr := range m.incomingRules { - for _, r := range arr { - if r.id == hookID { - delete(arr, r.id) - return nil - } - } +// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove. 
+func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { + if hook == nil { + m.tcpHookOut.Store(nil) + return } - // Check outgoing hooks - for _, arr := range m.outgoingRules { - for _, r := range arr { - if r.id == hookID { - delete(arr, r.id) - return nil - } - } - } - return fmt.Errorf("hook with given id not found") + m.tcpHookOut.Store(&packetHook{ + ip: ip, + port: dPort, + fn: hook, + }) } // SetLogLevel sets the log level for the firewall manager diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 55a8e723cb4..cb1fe6f26f4 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" wgdevice "golang.zx2c4.com/wireguard/device" @@ -186,81 +187,50 @@ func TestManagerDeleteRule(t *testing.T) { } } -func TestAddUDPPacketHook(t *testing.T) { - tests := []struct { - name string - in bool - expDir fw.RuleDirection - ip netip.Addr - dPort uint16 - hook func([]byte) bool - expectedID string - }{ - { - name: "Test Outgoing UDP Packet Hook", - in: false, - expDir: fw.RuleDirectionOUT, - ip: netip.MustParseAddr("10.168.0.1"), - dPort: 8000, - hook: func([]byte) bool { return true }, - }, - { - name: "Test Incoming UDP Packet Hook", - in: true, - expDir: fw.RuleDirectionIN, - ip: netip.MustParseAddr("::1"), - dPort: 9000, - hook: func([]byte) bool { return false }, - }, - } +func TestSetUDPPacketHook(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return 
nil }, - }, false, flowLogger, nbiface.DefaultMTU) - require.NoError(t, err) + var called bool + manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool { + called = true + return true + }) - manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) + h := manager.udpHookOut.Load() + require.NotNil(t, h) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip) + assert.Equal(t, uint16(8000), h.port) + assert.True(t, h.fn(nil)) + assert.True(t, called) - var addedRule PeerRule - if tt.in { - // Incoming UDP hooks are stored in allow rules map - if len(manager.incomingRules[tt.ip]) != 1 { - t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip])) - return - } - for _, rule := range manager.incomingRules[tt.ip] { - addedRule = rule - } - } else { - if len(manager.outgoingRules[tt.ip]) != 1 { - t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip])) - return - } - for _, rule := range manager.outgoingRules[tt.ip] { - addedRule = rule - } - } + manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil) + assert.Nil(t, manager.udpHookOut.Load()) +} - if tt.ip.Compare(addedRule.ip) != 0 { - t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) - return - } - if tt.dPort != addedRule.dPort.Values[0] { - t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0]) - return - } - if layers.LayerTypeUDP != addedRule.protoLayer { - t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) - return - } - if addedRule.udpHook == nil { - t.Errorf("expected udpHook to be set") - return - } - }) - } +func TestSetTCPPacketHook(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + + var called bool + manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool { + called = true + return true + }) + 
+ h := manager.tcpHookOut.Load() + require.NotNil(t, h) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip) + assert.Equal(t, uint16(53), h.port) + assert.True(t, h.fn(nil)) + assert.True(t, called) + + manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil) + assert.Nil(t, manager.tcpHookOut.Load()) } // TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added @@ -530,39 +500,12 @@ func TestRemovePacketHook(t *testing.T) { require.NoError(t, manager.Close(nil)) }() - // Add a UDP packet hook - hookFunc := func(data []byte) bool { return true } - hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc) - - // Assert the hook is added by finding it in the manager's outgoing rules - found := false - for _, arr := range manager.outgoingRules { - for _, rule := range arr { - if rule.id == hookID { - found = true - break - } - } - } + manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true }) - if !found { - t.Fatalf("The hook was not added properly.") - } + require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered") - // Now remove the packet hook - err = manager.RemovePacketHook(hookID) - if err != nil { - t.Fatalf("Failed to remove hook: %s", err) - } - - // Assert the hook is removed by checking it in the manager's outgoing rules - for _, arr := range manager.outgoingRules { - for _, rule := range arr { - if rule.id == hookID { - t.Fatalf("The hook was not removed properly.") - } - } - } + manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil) + assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed") } func TestProcessOutgoingHooks(t *testing.T) { @@ -592,8 +535,7 @@ func TestProcessOutgoingHooks(t *testing.T) { } hookCalled := false - hookID := manager.AddUDPPacketHook( - false, + manager.SetUDPPacketHook( netip.MustParseAddr("100.10.0.100"), 53, func([]byte) bool { @@ -601,7 +543,6 @@ func 
TestProcessOutgoingHooks(t *testing.T) { return true }, ) - require.NotEmpty(t, hookID) // Create test UDP packet ipv4 := &layers.IPv4{ diff --git a/client/firewall/uspfilter/hooks_bench_test.go b/client/firewall/uspfilter/hooks_bench_test.go new file mode 100644 index 00000000000..be6a8408b30 --- /dev/null +++ b/client/firewall/uspfilter/hooks_bench_test.go @@ -0,0 +1,182 @@ +package uspfilter + +import ( + "net" + "net/netip" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + nbiface "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +func buildUDPPacket(b *testing.B, srcIP, dstIP string, srcPort, dstPort uint16) []byte { + b.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + Protocol: layers.IPProtocolUDP, + } + udpLayer := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + if err := udpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + b.Fatal(err) + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + if err := gopacket.SerializeLayers(buf, opts, ipLayer, udpLayer, gopacket.Payload([]byte("test"))); err != nil { + b.Fatal(err) + } + return buf.Bytes() +} + +func buildTCPPacket(b *testing.B, srcIP, dstIP string, srcPort, dstPort uint16) []byte { + b.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + Protocol: layers.IPProtocolTCP, + } + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + } + if err := tcpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + b.Fatal(err) + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + if err := 
gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte("test"))); err != nil { + b.Fatal(err) + } + return buf.Bytes() +} + +func newBenchManager(b *testing.B) *Manager { + b.Helper() + m, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, nbiface.DefaultMTU) + require.NoError(b, err) + return m +} + +// BenchmarkHooksDrop_UDPMatch measures the cost of the UDP hook check when the +// packet matches the registered hook (the DNS interception fast path). +func BenchmarkHooksDrop_UDPMatch(b *testing.B) { + m := newBenchManager(b) + m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) + + pkt := buildUDPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) + + b.ResetTimer() + b.ReportAllocs() + for b.Loop() { + m.udpHooksDrop(53, netip.MustParseAddr("100.10.255.254"), pkt) + } +} + +// BenchmarkHooksDrop_UDPMiss measures the cost when no UDP hook matches +// (common case for non-DNS traffic). +func BenchmarkHooksDrop_UDPMiss(b *testing.B) { + m := newBenchManager(b) + m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) + + pkt := buildUDPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 8080) + + b.ResetTimer() + b.ReportAllocs() + for b.Loop() { + m.udpHooksDrop(8080, netip.MustParseAddr("100.10.0.2"), pkt) + } +} + +// BenchmarkHooksDrop_TCPMatch measures the TCP hook check when matching (DNS TCP). +func BenchmarkHooksDrop_TCPMatch(b *testing.B) { + m := newBenchManager(b) + m.SetTCPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) + + pkt := buildTCPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) + + b.ResetTimer() + b.ReportAllocs() + for b.Loop() { + m.tcpHooksDrop(53, netip.MustParseAddr("100.10.255.254"), pkt) + } +} + +// BenchmarkHooksDrop_TCPMiss measures TCP hook check for non-matching traffic. 
+func BenchmarkHooksDrop_TCPMiss(b *testing.B) { + m := newBenchManager(b) + m.SetTCPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) + + pkt := buildTCPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 443) + + b.ResetTimer() + b.ReportAllocs() + for b.Loop() { + m.tcpHooksDrop(443, netip.MustParseAddr("100.10.0.2"), pkt) + } +} + +// BenchmarkHooksDrop_NoHooks measures the cost when no hooks are registered +// (the baseline for all non-DNS traffic). +func BenchmarkHooksDrop_NoHooks(b *testing.B) { + m := newBenchManager(b) + + pkt := buildUDPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 8080) + + b.ResetTimer() + b.ReportAllocs() + for b.Loop() { + m.udpHooksDrop(8080, netip.MustParseAddr("100.10.0.2"), pkt) + m.tcpHooksDrop(8080, netip.MustParseAddr("100.10.0.2"), pkt) + } +} + +// BenchmarkFilterOutbound_WithHooks benchmarks the full FilterOutbound path +// with both UDP and TCP hooks registered (the real-world DNS scenario). +func BenchmarkFilterOutbound_WithHooks(b *testing.B) { + m := newBenchManager(b) + m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) + m.SetTCPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) + + udpDNS := buildUDPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) + tcpDNS := buildTCPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) + tcpHTTPS := buildTCPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 443) + + b.Run("udp_dns_match", func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + m.FilterOutbound(udpDNS, len(udpDNS)) + } + }) + + b.Run("tcp_dns_match", func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + m.FilterOutbound(tcpDNS, len(tcpDNS)) + } + }) + + b.Run("tcp_https_miss", func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + m.FilterOutbound(tcpHTTPS, len(tcpHTTPS)) + } + }) +} diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index dbe3a78583a..08d68a78ece 
100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -18,9 +18,7 @@ type PeerRule struct { protoLayer gopacket.LayerType sPort *firewall.Port dPort *firewall.Port - drop bool - - udpHook func([]byte) bool + drop bool } // ID returns the rule id diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index d9f9f1aa80c..657f96fc0f6 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) { { name: "UDPTraffic_WithHook", setup: func(m *Manager) { - hookFunc := func([]byte) bool { - return true - } - m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc) + m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { + return true // drop (intercepted by hook) + }) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN) + return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT) }, expectedStages: []PacketStage{ StageReceived, - StageInboundPortDNAT, - StageInbound1to1NAT, - StageConntrack, - StageRouting, - StagePeerACL, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: false, diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 708f38d2620..4357d191668 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -15,14 +15,17 @@ type PacketFilter interface { // FilterInbound filter incoming packets from external sources to host FilterInbound(packetData []byte, size int) bool - // AddUDPPacketHook calls hook when UDP packet from given direction matched - // - // Hook function returns flag which indicates should be the matched package dropped or not. - // Hook function receives raw network packet data as argument. 
- AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string - - // RemovePacketHook removes hook by ID - RemovePacketHook(hookID string) error + // SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port. + // Hook function returns true if the packet should be dropped. + // Only one UDP hook is supported; calling again replaces the previous hook. + // Pass nil hook to remove. + SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) + + // SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port. + // Hook function returns true if the packet should be dropped. + // Only one TCP hook is supported; calling again replaces the previous hook. + // Pass nil hook to remove. + SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) } // FilteredDevice to override Read or Write of packets diff --git a/client/iface/mocks/filter.go b/client/iface/mocks/filter.go index 566068aa578..5ae98039c0e 100644 --- a/client/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { return m.recorder } -// AddUDPPacketHook mocks base method. -func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string { +// SetUDPPacketHook mocks base method. +func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(string) - return ret0 + m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2) +} + +// SetUDPPacketHook indicates an expected call of SetUDPPacketHook. 
+func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2) +} + +// SetTCPPacketHook mocks base method. +func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2) } -// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. -func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// SetTCPPacketHook indicates an expected call of SetTCPPacketHook. +func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2) } // FilterInbound mocks base method. @@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1) } - -// RemovePacketHook mocks base method. -func (m *MockPacketFilter) RemovePacketHook(arg0 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemovePacketHook", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemovePacketHook indicates an expected call of RemovePacketHook. 
-func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) -} diff --git a/client/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go deleted file mode 100644 index 291ab9ab557..00000000000 --- a/client/iface/mocks/iface/mocks/filter.go +++ /dev/null @@ -1,87 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockPacketFilter is a mock of PacketFilter interface. -type MockPacketFilter struct { - ctrl *gomock.Controller - recorder *MockPacketFilterMockRecorder -} - -// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter. -type MockPacketFilterMockRecorder struct { - mock *MockPacketFilter -} - -// NewMockPacketFilter creates a new mock instance. -func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter { - mock := &MockPacketFilter{ctrl: ctrl} - mock.recorder = &MockPacketFilterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { - return m.recorder -} - -// AddUDPPacketHook mocks base method. -func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) -} - -// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. 
-func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) -} - -// FilterInbound mocks base method. -func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FilterInbound", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// FilterInbound indicates an expected call of FilterInbound. -func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0) -} - -// FilterOutbound mocks base method. -func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FilterOutbound", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// FilterOutbound indicates an expected call of FilterOutbound. -func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0) -} - -// SetNetwork mocks base method. -func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetNetwork", arg0) -} - -// SetNetwork indicates an expected call of SetNetwork. 
-func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0) -} diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 06a2056b159..f183052cd6a 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -195,10 +195,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { startTime := time.Now() requestID := resutil.GenerateRequestID() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": requestID, "dns_id": fmt.Sprintf("%04x", r.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) question := r.Question[0] qname := strings.ToLower(question.Name) @@ -261,9 +265,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q meta += " " + k + "=" + v } - logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s", + logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s", qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer), - meta, time.Since(startTime)) + cw.response.Len(), meta, time.Since(startTime)) } func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { diff --git a/client/internal/dns/response_writer.go b/client/internal/dns/response_writer.go index edc65a5d9e8..1268b85f5a6 100644 --- a/client/internal/dns/response_writer.go +++ b/client/internal/dns/response_writer.go @@ -104,3 +104,39 @@ func (r *responseWriter) TsigTimersOnly(bool) { // After a call to Hijack(), the DNS package will not do anything with the connection. 
func (r *responseWriter) Hijack() { } + +// truncationAwareWriter wraps a UDP responseWriter and starts the TCP DNS +// stack when a truncated response is about to be sent. This ensures the +// TCP stack is ready when the client retries over TCP. +type truncationAwareWriter struct { + responseWriter + tcpDNS *tcpDNSServer +} + +// WriteMsg checks if the response is truncated and starts the TCP stack if needed. +func (w *truncationAwareWriter) WriteMsg(msg *dns.Msg) error { + if msg.MsgHdr.Truncated && w.tcpDNS != nil { + w.tcpDNS.EnsureRunning() + } + return w.responseWriter.WriteMsg(msg) +} + +// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging. +func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr { + var srcIP net.IP + if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil { + srcIP = ipv4.(*layers.IPv4).SrcIP + } else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil { + srcIP = ipv6.(*layers.IPv6).SrcIP + } + + var srcPort int + if udp := packet.Layer(layers.LayerTypeUDP); udp != nil { + srcPort = int(udp.(*layers.UDP).SrcPort) + } + + if srcIP == nil { + return nil + } + return &net.UDPAddr{IP: srcIP, Port: srcPort} +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 3c47f4ee639..e14824231f4 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -130,6 +130,7 @@ type registeredHandlerMap map[types.HandlerID]handlerWrapper // DefaultServerConfig holds configuration parameters for NewDefaultServer type DefaultServerConfig struct { WgInterface WGIface + Firewall DNSFirewall CustomAddress string StatusRecorder *peer.Status StateManager *statemanager.Manager @@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default if config.WgInterface.IsUserspaceBind() { dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(config.WgInterface, addrPort) + dnsService = 
newServiceViaListener(config.WgInterface, addrPort, config.Firewall) } server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) @@ -396,7 +397,11 @@ func (s *DefaultServer) Stop() { } func (s *DefaultServer) disableDNS() error { - defer s.service.Stop() + defer func() { + if err := s.service.Stop(); err != nil { + log.Errorf("failed to stop DNS service: %v", err) + } + }() if s.isUsingNoopHostManager() { return nil diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index d3b0c250d22..f77f6e89812 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes() - packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - packetfilter.EXPECT().RemovePacketHook(gomock.Any()) + packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() if err := wgIface.SetFilter(packetfilter); err != nil { t.Errorf("set packet filter: %v", err) @@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand type mockService struct{} func (m *mockService) Listen() error { return nil } -func (m *mockService) Stop() {} +func (m *mockService) Stop() error { return nil } func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") } func (m *mockService) RuntimePort() int { return 53 } func (m *mockService) RegisterMux(string, dns.Handler) {} diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index 6a76c53e3f3..bd07110584e 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -4,15 +4,25 @@ import ( 
"net/netip" "github.com/miekg/dns" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" ) const ( DefaultPort = 53 ) +// DNSFirewall provides DNAT capabilities for DNS port redirection. +// This is used when the DNS server cannot bind port 53 directly +// and needs firewall rules to redirect traffic. +type DNSFirewall interface { + AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error + RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error +} + type service interface { Listen() error - Stop() + Stop() error RegisterMux(domain string, handler dns.Handler) DeregisterMux(key string) RuntimePort() int diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index f7ddfd40f43..315d4f09219 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -10,9 +10,13 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ) @@ -31,25 +35,33 @@ type serviceViaListener struct { dnsMux *dns.ServeMux customAddr *netip.AddrPort server *dns.Server + tcpServer *dns.Server listenIP netip.Addr listenPort uint16 listenerIsRunning bool listenerFlagLock sync.Mutex ebpfService ebpfMgr.Manager + firewall DNSFirewall + tcpDNATConfigured bool } -func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener { +func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, firewall DNSFirewall) *serviceViaListener { mux := dns.NewServeMux() s := &serviceViaListener{ wgInterface: wgIface, dnsMux: mux, customAddr: customAddr, + firewall: firewall, server: 
&dns.Server{ Net: "udp", Handler: mux, UDPSize: 65535, }, + tcpServer: &dns.Server{ + Net: "tcp", + Handler: mux, + }, } return s @@ -70,43 +82,75 @@ func (s *serviceViaListener) Listen() error { return fmt.Errorf("eval listen address: %w", err) } s.listenIP = s.listenIP.Unmap() - s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort))) - log.Debugf("starting dns on %s", s.server.Addr) + addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort))) + s.server.Addr = addr + s.tcpServer.Addr = addr + + log.Debugf("starting dns on %s (UDP + TCP)", addr) go func() { s.setListenerStatus(true) defer s.setListenerStatus(false) - err := s.server.ListenAndServe() - if err != nil { - log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err) + if err := s.server.ListenAndServe(); err != nil { + log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err) + } + }() + + go func() { + if err := s.tcpServer.ListenAndServe(); err != nil { + log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err) } }() + // When eBPF redirects UDP port 53 to our listen port, TCP still needs + // a DNAT rule because eBPF only handles UDP. 
+ if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort { + if err := s.firewall.AddInboundDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + s.tcpDNATConfigured = true + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort) + } + } + return nil } -func (s *serviceViaListener) Stop() { +func (s *serviceViaListener) Stop() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() if !s.listenerIsRunning { - return + return nil } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err := s.server.ShutdownContext(ctx) - if err != nil { - log.Errorf("stopping dns server listener returned an error: %v", err) + var merr *multierror.Error + + if err := s.server.ShutdownContext(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err)) + } + + if err := s.tcpServer.ShutdownContext(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err)) + } + + if s.tcpDNATConfigured && s.firewall != nil { + if err := s.firewall.RemoveInboundDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + s.tcpDNATConfigured = false } if s.ebpfService != nil { - err = s.ebpfService.FreeDNSFwd() - if err != nil { - log.Errorf("stopping traffic forwarder returned an error: %v", err) + if err := s.ebpfService.FreeDNSFwd(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err)) } } + + return nberrors.FormatErrorOrNil(merr) } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { @@ -187,18 +231,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) { } func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool 
{ - addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port)) - udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) - probeListener, err := net.ListenUDP("udp", udpAddr) + addrPort := netip.AddrPortFrom(ip, uint16(port)) + + udpAddr := net.UDPAddrFromAddrPort(addrPort) + udpLn, err := net.ListenUDP("udp", udpAddr) if err != nil { - log.Warnf("binding dns on %s is not available, error: %s", addrString, err) + log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err) return false } + if err := udpLn.Close(); err != nil { + log.Debugf("close UDP probe listener: %s", err) + } - err = probeListener.Close() + tcpAddr := net.TCPAddrFromAddrPort(addrPort) + tcpLn, err := net.ListenTCP("tcp", tcpAddr) if err != nil { - log.Errorf("got an error closing the probe listener, error: %s", err) + log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err) + return false + } + if err := tcpLn.Close(); err != nil { + log.Debugf("close TCP probe listener: %s", err) } + return true } diff --git a/client/internal/dns/service_listener_test.go b/client/internal/dns/service_listener_test.go new file mode 100644 index 00000000000..51f52020f19 --- /dev/null +++ b/client/internal/dns/service_listener_test.go @@ -0,0 +1,89 @@ +package dns + +import ( + "fmt" + "net" + "net/netip" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServiceViaListener_TCPAndUDP(t *testing.T) { + handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("192.0.2.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + // Create a service using a custom address to avoid needing root + svc := newServiceViaListener(nil, nil, nil) + 
svc.dnsMux.Handle(".", handler) + + // Find a free port by binding and releasing + udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0)) + udpLn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + t.Skip("cannot bind to 127.0.0.153, skipping") + } + port := uint16(udpLn.LocalAddr().(*net.UDPAddr).Port) + require.NoError(t, udpLn.Close()) + + // Check TCP is also available on this port + tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port)) + tcpLn, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + t.Skip("cannot bind TCP on same port, skipping") + } + require.NoError(t, tcpLn.Close()) + + addr := fmt.Sprintf("%s:%d", customIP, port) + svc.server.Addr = addr + svc.tcpServer.Addr = addr + svc.listenIP = customIP + svc.listenPort = port + + go func() { + if err := svc.server.ListenAndServe(); err != nil { + t.Logf("udp server: %v", err) + } + }() + go func() { + if err := svc.tcpServer.ListenAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + svc.listenerIsRunning = true + + defer svc.Stop() + + // Wait for servers to start + time.Sleep(100 * time.Millisecond) + + q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + // Test UDP query + udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second} + udpResp, _, err := udpClient.Exchange(q, addr) + require.NoError(t, err, "UDP query should succeed") + require.NotNil(t, udpResp) + require.NotEmpty(t, udpResp.Answer) + assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP") + + // Test TCP query + tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second} + tcpResp, _, err := tcpClient.Exchange(q, addr) + require.NoError(t, err, "TCP query should succeed") + require.NotNil(t, tcpResp) + require.NotEmpty(t, tcpResp.Answer) + assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP") +} diff --git a/client/internal/dns/service_memory.go 
b/client/internal/dns/service_memory.go index 6ef0ab5268f..d6cc494b888 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -1,6 +1,7 @@ package dns import ( + "errors" "fmt" "net/netip" "sync" @@ -10,6 +11,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -18,7 +20,8 @@ type ServiceViaMemory struct { dnsMux *dns.ServeMux runtimeIP netip.Addr runtimePort int - udpFilterHookID string + tcpDNS *tcpDNSServer + tcpHookSet bool listenerIsRunning bool listenerFlagLock sync.Mutex } @@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { if err != nil { log.Errorf("get last ip from network: %v", err) } - s := &ServiceViaMemory{ + + return &ServiceViaMemory{ wgInterface: wgIface, dnsMux: dns.NewServeMux(), - runtimeIP: lastIP, runtimePort: DefaultPort, } - return s } func (s *ServiceViaMemory) Listen() error { @@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error { return nil } - var err error - s.udpFilterHookID, err = s.filterDNSTraffic() - if err != nil { - return fmt.Errorf("filter dns traffice: %w", err) + if err := s.filterDNSTraffic(); err != nil { + return fmt.Errorf("filter dns traffic: %w", err) } s.listenerIsRunning = true @@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error { return nil } -func (s *ServiceViaMemory) Stop() { +func (s *ServiceViaMemory) Stop() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() if !s.listenerIsRunning { - return + return nil + } + + filter := s.wgInterface.GetFilter() + if filter != nil { + filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil) + if s.tcpHookSet { + filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil) + } } - if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil { - log.Errorf("unable to remove DNS packet hook: %s", err) + if s.tcpDNS != nil { 
+ s.tcpDNS.Stop() } s.listenerIsRunning = false + + return nil } func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { @@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr { return s.runtimeIP } -func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { +func (s *ServiceViaMemory) filterDNSTraffic() error { filter := s.wgInterface.GetFilter() if filter == nil { - return "", fmt.Errorf("can't set DNS filter, filter not initialized") + return errors.New("DNS filter not initialized") + } + + // Create TCP DNS server lazily here since the device may not exist at construction time. + if s.tcpDNS == nil { + if dev := s.wgInterface.GetDevice(); dev != nil { + // MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact. + s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU) + } } firstLayerDecoder := layers.LayerTypeIPv4 @@ -100,10 +118,8 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { } hook := func(packetData []byte) bool { - // Decode the packet packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) - // Get the UDP layer udpLayer := packet.Layer(layers.LayerTypeUDP) udp := udpLayer.(*layers.UDP) @@ -113,13 +129,28 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { return true } - writer := responseWriter{ - packet: packet, - device: s.wgInterface.GetDevice().Device, + writer := &truncationAwareWriter{ + responseWriter: responseWriter{ + remote: remoteAddrFromPacket(packet), + packet: packet, + device: s.wgInterface.GetDevice().Device, + }, + tcpDNS: s.tcpDNS, } - go s.dnsMux.ServeDNS(&writer, msg) + go s.dnsMux.ServeDNS(writer, msg) return true } - return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil + filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook) + + if s.tcpDNS != nil { + tcpHook := func(packetData []byte) bool { + 
s.tcpDNS.InjectPacket(packetData) + return true + } + filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook) + s.tcpHookSet = true + } + + return nil } diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go new file mode 100644 index 00000000000..fc542b1edf3 --- /dev/null +++ b/client/internal/dns/tcpstack.go @@ -0,0 +1,459 @@ +package dns + +import ( + "errors" + "fmt" + "io" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + dnsTCPReceiveWindow = 8192 + dnsTCPMaxInFlight = 16 + dnsTCPIdleTimeout = 30 * time.Second + dnsTCPReadTimeout = 5 * time.Second +) + +// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack. +// It is started lazily when a truncated DNS response is detected and shuts down +// after a period of inactivity to conserve resources. +type tcpDNSServer struct { + mu sync.Mutex + s *stack.Stack + ep *dnsEndpoint + mux *dns.ServeMux + tunDev tun.Device + ip netip.Addr + port uint16 + mtu uint16 + + running bool + closed bool + timerID uint64 + timer *time.Timer +} + +func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer { + return &tcpDNSServer{ + mux: mux, + tunDev: tunDev, + ip: ip, + port: port, + mtu: mtu, + } +} + +// EnsureRunning starts the TCP stack if not already running and resets the idle timer. 
+func (t *tcpDNSServer) EnsureRunning() { + t.mu.Lock() + defer t.mu.Unlock() + + if t.closed { + return + } + + if t.running { + t.resetTimerLocked() + return + } + + if err := t.startLocked(); err != nil { + log.Errorf("failed to start TCP DNS stack: %v", err) + return + } + + t.running = true + t.resetTimerLocked() + log.Debugf("TCP DNS stack started on %s:%d", t.ip, t.port) +} + +// InjectPacket ensures the stack is running and delivers a raw IP packet into +// the gvisor stack for TCP processing. Combining both operations under a single +// lock prevents a race where the idle timer could stop the stack between +// EnsureRunning and delivery. +func (t *tcpDNSServer) InjectPacket(payload []byte) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.closed { + return + } + + if !t.running { + if err := t.startLocked(); err != nil { + log.Errorf("failed to start TCP DNS stack: %v", err) + return + } + t.running = true + log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload)) + } + t.resetTimerLocked() + + ep := t.ep + if ep == nil || ep.dispatcher == nil { + return + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + // DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef. + ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) +} + +// Stop tears down the gvisor stack and releases resources permanently. +// After Stop, EnsureRunning becomes a no-op. +func (t *tcpDNSServer) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopLocked() + t.closed = true +} + +func (t *tcpDNSServer) startLocked() error { + // TODO: add ipv6.NewProtocol when IPv6 overlay support lands. 
+ s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: false, + }) + + nicID := tcpip.NICID(1) + ep := &dnsEndpoint{ + tunDev: t.tunDev, + } + ep.mtu.Store(uint32(t.mtu)) + + if err := s.CreateNIC(nicID, ep); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("create NIC: %v", err) + } + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(t.ip.AsSlice()), + PrefixLen: 32, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("add protocol address: %s", err) + } + + if err := s.SetPromiscuousMode(nicID, true); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("set promiscuous mode: %s", err) + } + if err := s.SetSpoofing(nicID, true); err != nil { + s.Close() + s.Wait() + return fmt.Errorf("set spoofing: %s", err) + } + + defaultSubnet, err := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), + ) + if err != nil { + s.Close() + s.Wait() + return fmt.Errorf("create default subnet: %w", err) + } + + s.SetRouteTable([]tcpip.Route{ + {Destination: defaultSubnet, NIC: nicID}, + }) + + tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) { + t.handleTCPDNS(r) + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + t.s = s + t.ep = ep + return nil +} + +func (t *tcpDNSServer) stopLocked() { + if !t.running { + return + } + + if t.timer != nil { + t.timer.Stop() + t.timer = nil + } + + if t.s != nil { + t.s.Close() + t.s.Wait() + t.s = nil + } + t.ep = nil + t.running = false + + log.Debugf("TCP DNS stack stopped") +} + +func (t *tcpDNSServer) resetTimerLocked() { + if t.timer != nil { + t.timer.Stop() + } + t.timerID++ + id := 
t.timerID + t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() { + t.mu.Lock() + defer t.mu.Unlock() + + // Only stop if this timer is still the active one. + // A racing EnsureRunning may have replaced it. + if t.timerID != id { + return + } + t.stopLocked() + }) +} + +func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) { + id := r.ID() + + wq := waiter.Queue{} + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + log.Debugf("TCP DNS: failed to create endpoint: %v", epErr) + r.Complete(true) + return + } + r.Complete(false) + + conn := gonet.NewTCPConn(&wq, ep) + defer func() { + if err := conn.Close(); err != nil { + log.Tracef("TCP DNS: close conn: %v", err) + } + }() + + // Reset idle timer on activity + t.mu.Lock() + t.resetTimerLocked() + t.mu.Unlock() + + localAddr := &net.TCPAddr{ + IP: id.LocalAddress.AsSlice(), + Port: int(id.LocalPort), + } + remoteAddr := &net.TCPAddr{ + IP: id.RemoteAddress.AsSlice(), + Port: int(id.RemotePort), + } + + for { + if err := conn.SetDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil { + log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err) + break + } + + msg, err := readTCPDNSMessage(conn) + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err) + } + break + } + + writer := &tcpResponseWriter{ + conn: conn, + localAddr: localAddr, + remoteAddr: remoteAddr, + } + t.mux.ServeDNS(writer, msg) + } +} + +// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device. 
+type dnsEndpoint struct { + dispatcher stack.NetworkDispatcher + tunDev tun.Device + mtu atomic.Uint32 +} + +func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher } +func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil } +func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() } +func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone } +func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 } +func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" } +func (e *dnsEndpoint) Wait() {} +func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } +func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) {} +func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true } +func (e *dnsEndpoint) Close() {} +func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) {} +func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) } +func (e *dnsEndpoint) SetOnCloseAction(func()) {} + +const tunPacketOffset = 40 + +func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + var written int + for _, pkt := range pkts.AsSlice() { + data := stack.PayloadSince(pkt.NetworkHeader()) + if data == nil { + continue + } + + raw := data.AsSlice() + buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw)) + buf = append(buf, raw...) + data.Release() + + if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil { + log.Tracef("TCP DNS endpoint: failed to write packet: %v", err) + continue + } + written++ + } + return written, nil +} + +// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections. 
+type tcpResponseWriter struct { + conn *gonet.TCPConn + localAddr net.Addr + remoteAddr net.Addr +} + +func (w *tcpResponseWriter) LocalAddr() net.Addr { + return w.localAddr +} + +func (w *tcpResponseWriter) RemoteAddr() net.Addr { + return w.remoteAddr +} + +func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error { + data, err := msg.Pack() + if err != nil { + return fmt.Errorf("pack: %w", err) + } + + // DNS TCP: 2-byte length prefix + message + buf := make([]byte, 2+len(data)) + buf[0] = byte(len(data) >> 8) + buf[1] = byte(len(data)) + copy(buf[2:], data) + + if _, err = w.conn.Write(buf); err != nil { + return err + } + return nil +} + +func (w *tcpResponseWriter) Write(data []byte) (int, error) { + buf := make([]byte, 2+len(data)) + buf[0] = byte(len(data) >> 8) + buf[1] = byte(len(data)) + copy(buf[2:], data) + return w.conn.Write(buf) +} + +func (w *tcpResponseWriter) Close() error { + return w.conn.Close() +} + +func (w *tcpResponseWriter) TsigStatus() error { return nil } +func (w *tcpResponseWriter) TsigTimersOnly(bool) {} +func (w *tcpResponseWriter) Hijack() {} + +// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed). +func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) { + // DNS over TCP uses a 2-byte length prefix + lenBuf := make([]byte, 2) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return nil, fmt.Errorf("read length: %w", err) + } + + msgLen := int(lenBuf[0])<<8 | int(lenBuf[1]) + if msgLen == 0 || msgLen > 65535 { + return nil, fmt.Errorf("invalid message length: %d", msgLen) + } + + msgBuf := make([]byte, msgLen) + if _, err := io.ReadFull(conn, msgBuf); err != nil { + return nil, fmt.Errorf("read message: %w", err) + } + + msg := new(dns.Msg) + if err := msg.Unpack(msgBuf); err != nil { + return nil, fmt.Errorf("unpack: %w", err) + } + return msg, nil +} + +// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging. +// Supports both IPv4 and IPv6. 
+func srcAddrFromPacket(pkt []byte) netip.AddrPort { + if len(pkt) == 0 { + return netip.AddrPort{} + } + + var srcIP netip.Addr + var transportOffset int + + switch header.IPVersion(pkt) { + case 4: + if len(pkt) < header.IPv4MinimumSize { + return netip.AddrPort{} + } + hdr := header.IPv4(pkt) + src := hdr.SourceAddress() + var ok bool + srcIP, ok = netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.AddrPort{} + } + transportOffset = int(hdr.HeaderLength()) + case 6: + if len(pkt) < header.IPv6MinimumSize { + return netip.AddrPort{} + } + hdr := header.IPv6(pkt) + src := hdr.SourceAddress() + var ok bool + srcIP, ok = netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.AddrPort{} + } + transportOffset = header.IPv6MinimumSize + default: + return netip.AddrPort{} + } + + // TCP source port is the first 2 bytes of the transport header. + if len(pkt) < transportOffset+2 { + return netip.AddrPort{} + } + srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1]) + return netip.AddrPortFrom(srcIP.Unmap(), srcPort) +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 5b813513272..c61a569a5cc 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -45,6 +45,53 @@ const ( const testRecord = "com." +const ( + protoUDP = "udp" + protoTCP = "tcp" +) + +type dnsProtocolKey struct{} + +// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context. +func contextWithDNSProtocol(ctx context.Context, network string) context.Context { + return context.WithValue(ctx, dnsProtocolKey{}, network) +} + +// dnsProtocolFromContext retrieves the inbound DNS protocol from context. 
+func dnsProtocolFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok { + return v + } + return "" +} + +type upstreamProtocolKey struct{} + +// upstreamProtocolResult holds the protocol used for the upstream exchange. +// Stored as a pointer in context so the exchange function can set it. +type upstreamProtocolResult struct { + protocol string +} + +// contextWithupstreamProtocolResult stores a mutable result holder in the context. +func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { + r := &upstreamProtocolResult{} + return context.WithValue(ctx, upstreamProtocolKey{}, r), r +} + +// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present. +func setUpstreamProtocol(ctx context.Context, protocol string) { + if ctx == nil { + return + } + if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil { + r.protocol = protocol + } +} + type upstreamClient interface { exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) } @@ -138,7 +185,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - ok, failures := u.tryUpstreamServers(w, r, logger) + // Propagate inbound protocol so upstream exchange can use TCP directly + // when the request came in over TCP. 
+ ctx := u.ctx + if addr := w.RemoteAddr(); addr != nil { + network := addr.Network() + ctx = contextWithDNSProtocol(ctx, network) + resutil.SetMeta(w, "protocol", network) + } + + ok, failures := u.tryUpstreamServers(ctx, w, r, logger) if len(failures) > 0 { u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger) } @@ -153,7 +209,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } } -func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { +func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { timeout := u.upstreamTimeout if len(u.upstreamServers) > 1 { maxTotal := 5 * time.Second @@ -168,7 +224,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M var failures []upstreamFailure for _, upstream := range u.upstreamServers { - if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil { + if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil { failures = append(failures, *failure) } else { return true, failures @@ -178,15 +234,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M } // queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. 
-func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { var rm *dns.Msg var t time.Duration var err error var startTime time.Time + var upstreamProto *upstreamProtocolResult func() { - ctx, cancel := context.WithTimeout(u.ctx, timeout) + ctx, cancel := context.WithTimeout(parentCtx, timeout) defer cancel() + ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) startTime = time.Now() rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) }() @@ -203,7 +261,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) + u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) return nil } @@ -220,10 +278,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add return &upstreamFailure{upstream: upstream, reason: reason} } -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool { u.successCount.Add(1) resutil.SetMeta(w, "upstream", upstream.String()) + if upstreamProto != nil && upstreamProto.protocol != "" { + resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol) + } // Clear Zero bit from external responses to prevent upstream servers from // manipulating our internal fallthrough signaling 
mechanism @@ -430,8 +491,21 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC // ExchangeWithFallback exchanges a DNS message with the upstream server. // It first tries to use UDP, and if it is truncated, it falls back to TCP. +// If the inbound request came over TCP (via context), it skips the UDP attempt. // If the passed context is nil, this will use Exchange instead of ExchangeContext. func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { + // If the request came in over TCP, go straight to TCP upstream. + if dnsProtocolFromContext(ctx) == protoTCP { + tcpClient := *client + tcpClient.Net = protoTCP + rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream) + if err != nil { + return nil, t, fmt.Errorf("with tcp: %w", err) + } + setUpstreamProtocol(ctx, protoTCP) + return rm, t, nil + } + // MTU - ip + udp headers // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling. 
client.UDPSize = uint16(currentMTU - (60 + 8)) @@ -453,25 +527,37 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u } if rm == nil || !rm.MsgHdr.Truncated { + setUpstreamProtocol(ctx, protoUDP) return rm, t, nil } log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - client.Net = "tcp" + tcpClient := *client + tcpClient.Net = protoTCP if ctx == nil { - rm, t, err = client.Exchange(r, upstream) + rm, t, err = tcpClient.Exchange(r, upstream) } else { - rm, t, err = client.ExchangeContext(ctx, r, upstream) + rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream) } if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } - // TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP + setUpstreamProtocol(ctx, protoTCP) + + // Request came in over UDP but response was fetched via TCP. + // Truncate to fit the client's UDP buffer. + maxSize := dns.MinMsgSize + if opt := r.IsEdns0(); opt != nil { + maxSize = int(opt.UDPSize()) + } + if rm.Len() > maxSize { + rm.Truncate(maxSize) + } return rm, t, nil } @@ -479,7 +565,17 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u // ExchangeWithNetstack performs a DNS exchange using netstack for dialing. // This is needed when netstack is enabled to reach peer IPs through the tunnel. 
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) { - reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp") + // If request came in over TCP, go straight to TCP upstream + if dnsProtocolFromContext(ctx) == protoTCP { + rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP) + if err != nil { + return nil, err + } + setUpstreamProtocol(ctx, protoTCP) + return rm, nil + } + + reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP) if err != nil { return nil, err } @@ -488,9 +584,29 @@ func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, if reply != nil && reply.MsgHdr.Truncated { log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - return netstackExchange(ctx, nsNet, r, upstream, "tcp") + + rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP) + if err != nil { + return nil, err + } + + setUpstreamProtocol(ctx, protoTCP) + + // Request came in over UDP but response was fetched via TCP. + // Truncate to fit the client's UDP buffer. 
+ maxSize := dns.MinMsgSize + if opt := r.IsEdns0(); opt != nil { + maxSize = int(opt.UDPSize()) + } + if rm.Len() > maxSize { + rm.Truncate(maxSize) + } + + return rm, nil } + setUpstreamProtocol(ctx, protoUDP) + return reply, nil } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index d7cff377bf0..ee1ca42fe2a 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin upstreamExchangeClient := &dns.Client{ Timeout: ClientTimeout, } - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream) } // exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN @@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri Timeout: timeout, } - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index ab164c30b8d..cb9fdf51786 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -475,3 +475,180 @@ func TestFormatFailures(t *testing.T) { }) } } + +func TestDNSProtocolContext(t *testing.T) { + t.Run("roundtrip udp", func(t *testing.T) { + ctx := contextWithDNSProtocol(context.Background(), protoUDP) + assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx)) + }) + + t.Run("roundtrip tcp", func(t *testing.T) { + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx)) + }) + + t.Run("missing returns empty", func(t *testing.T) { + assert.Equal(t, "", dnsProtocolFromContext(context.Background())) + }) +} + +func 
TestExchangeWithFallback_TCPContext(t *testing.T) { + // Start a local DNS server that responds on TCP only + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpServer := &dns.Server{ + Addr: "127.0.0.1:0", + Net: "tcp", + Handler: tcpHandler, + } + + tcpLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tcpServer.Listener = tcpLn + + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = tcpServer.Shutdown() + }() + + upstream := tcpLn.Addr().String() + + // With TCP context, should connect directly via TCP without trying UDP + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, upstream) + require.NoError(t, err) + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer) + assert.Contains(t, rm.Answer[0].String(), "10.0.0.1") +} + +func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) { + // Start a server on both UDP and TCP. + // The handler returns a small response that works on both. 
+ handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.3"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{ + PacketConn: udpPC, + Net: "udp", + Handler: handler, + } + + tcpLn, err := net.Listen("tcp", addr) + require.NoError(t, err) + + tcpServer := &dns.Server{ + Listener: tcpLn, + Net: "tcp", + Handler: handler, + } + + go func() { + if err := udpServer.ActivateAndServe(); err != nil { + t.Logf("udp server: %v", err) + } + }() + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = udpServer.Shutdown() + _ = tcpServer.Shutdown() + }() + + // Normal UDP exchange without TCP context should succeed via UDP + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err) + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer) + assert.Contains(t, rm.Answer[0].String(), "10.0.0.3") + assert.False(t, rm.Truncated, "small response should not be truncated") +} + +func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) { + // Start only a TCP server (no UDP). With TCP context it should succeed. 
+ tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.2"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + tcpServer := &dns.Server{ + Listener: tcpLn, + Net: "tcp", + Handler: tcpHandler, + } + + go func() { + if err := tcpServer.ActivateAndServe(); err != nil { + t.Logf("tcp server: %v", err) + } + }() + defer func() { + _ = tcpServer.Shutdown() + }() + + upstream := tcpLn.Addr().String() + + // TCP context: should skip UDP entirely and go directly to TCP + ctx := contextWithDNSProtocol(context.Background(), protoTCP) + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + + rm, _, err := ExchangeWithFallback(ctx, client, r, upstream) + require.NoError(t, err) + require.NotNil(t, rm) + require.NotEmpty(t, rm.Answer) + assert.Contains(t, rm.Answer[0].String(), "10.0.0.2") + + // Without TCP context, trying to reach a TCP-only server via UDP should fail + ctx2 := context.Background() + client2 := &dns.Client{Timeout: 500 * time.Millisecond} + _, _, err = ExchangeWithFallback(ctx2, client2, r, upstream) + assert.Error(t, err, "should fail when no UDP server and no TCP context") +} diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 5c7cb31fc10..2e8ef84ab59 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re return } - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + logger.Tracef("response: 
domain=%s rcode=%s answers=%s size=%dB took=%s", + qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime)) } // udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation. @@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { startTime := time.Now() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": resutil.GenerateRequestID(), "dns_id": fmt.Sprintf("%04x", query.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { startTime := time.Now() - logger := log.WithFields(log.Fields{ + fields := log.Fields{ "request_id": resutil.GenerateRequestID(), "dns_id": fmt.Sprintf("%04x", query.Id), - }) + } + if addr := w.RemoteAddr(); addr != nil { + fields["client"] = addr.String() + } + logger := log.WithFields(fields) f.handleDNSQuery(logger, w, query, startTime) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 7b100bd0cb4..cceb80ff840 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1807,6 +1807,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ WgInterface: e.wgInterface, + Firewall: e.firewall, CustomAddress: e.config.CustomDNSAddress, StatusRecorder: e.statusRecorder, StateManager: e.stateManager, From 491db3baafb98e7ce3ccf263fb10f516be636b96 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 1 Apr 2026 11:55:14 +0200 Subject: [PATCH 2/8] Fix sonar: add comments to empty stubs, extract switch cases --- client/internal/dns/tcpstack.go | 52 
++++++++++++++++----------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go index fc542b1edf3..d278b75f758 100644 --- a/client/internal/dns/tcpstack.go +++ b/client/internal/dns/tcpstack.go @@ -304,14 +304,14 @@ func (e *dnsEndpoint) MTU() uint32 { return e.m func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone } func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 } func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (e *dnsEndpoint) Wait() {} +func (e *dnsEndpoint) Wait() { /* no async work */ } func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } -func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) {} +func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ } func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true } -func (e *dnsEndpoint) Close() {} -func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) {} +func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ } +func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ } func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) } -func (e *dnsEndpoint) SetOnCloseAction(func()) {} +func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ } const tunPacketOffset = 40 @@ -383,8 +383,8 @@ func (w *tcpResponseWriter) Close() error { } func (w *tcpResponseWriter) TsigStatus() error { return nil } -func (w *tcpResponseWriter) TsigTimersOnly(bool) {} -func (w *tcpResponseWriter) Hijack() {} +func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ } +func (w *tcpResponseWriter) Hijack() { /* not supported */ } // readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed). 
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) { @@ -418,42 +418,40 @@ func srcAddrFromPacket(pkt []byte) netip.AddrPort { return netip.AddrPort{} } - var srcIP netip.Addr - var transportOffset int + srcIP, transportOffset := srcIPFromPacket(pkt) + if !srcIP.IsValid() || len(pkt) < transportOffset+2 { + return netip.AddrPort{} + } + + srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1]) + return netip.AddrPortFrom(srcIP.Unmap(), srcPort) +} +func srcIPFromPacket(pkt []byte) (netip.Addr, int) { switch header.IPVersion(pkt) { case 4: if len(pkt) < header.IPv4MinimumSize { - return netip.AddrPort{} + return netip.Addr{}, 0 } hdr := header.IPv4(pkt) src := hdr.SourceAddress() - var ok bool - srcIP, ok = netip.AddrFromSlice(src.AsSlice()) + ip, ok := netip.AddrFromSlice(src.AsSlice()) if !ok { - return netip.AddrPort{} + return netip.Addr{}, 0 } - transportOffset = int(hdr.HeaderLength()) + return ip, int(hdr.HeaderLength()) case 6: if len(pkt) < header.IPv6MinimumSize { - return netip.AddrPort{} + return netip.Addr{}, 0 } hdr := header.IPv6(pkt) src := hdr.SourceAddress() - var ok bool - srcIP, ok = netip.AddrFromSlice(src.AsSlice()) + ip, ok := netip.AddrFromSlice(src.AsSlice()) if !ok { - return netip.AddrPort{} + return netip.Addr{}, 0 } - transportOffset = header.IPv6MinimumSize + return ip, header.IPv6MinimumSize default: - return netip.AddrPort{} - } - - // TCP source port is the first 2 bytes of the transport header. 
- if len(pkt) < transportOffset+2 { - return netip.AddrPort{} + return netip.Addr{}, 0 } - srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1]) - return netip.AddrPortFrom(srcIP.Unmap(), srcPort) } From 9da2fcbf07d57cf5d334c55ac3a5b5ac8741351b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 1 Apr 2026 11:58:30 +0200 Subject: [PATCH 3/8] Extract IPv4/IPv6 parsing into separate functions --- client/internal/dns/tcpstack.go | 48 +++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go index d278b75f758..cc61e92aad2 100644 --- a/client/internal/dns/tcpstack.go +++ b/client/internal/dns/tcpstack.go @@ -430,28 +430,36 @@ func srcAddrFromPacket(pkt []byte) netip.AddrPort { func srcIPFromPacket(pkt []byte) (netip.Addr, int) { switch header.IPVersion(pkt) { case 4: - if len(pkt) < header.IPv4MinimumSize { - return netip.Addr{}, 0 - } - hdr := header.IPv4(pkt) - src := hdr.SourceAddress() - ip, ok := netip.AddrFromSlice(src.AsSlice()) - if !ok { - return netip.Addr{}, 0 - } - return ip, int(hdr.HeaderLength()) + return srcIPv4(pkt) case 6: - if len(pkt) < header.IPv6MinimumSize { - return netip.Addr{}, 0 - } - hdr := header.IPv6(pkt) - src := hdr.SourceAddress() - ip, ok := netip.AddrFromSlice(src.AsSlice()) - if !ok { - return netip.Addr{}, 0 - } - return ip, header.IPv6MinimumSize + return srcIPv6(pkt) default: return netip.Addr{}, 0 } } + +func srcIPv4(pkt []byte) (netip.Addr, int) { + if len(pkt) < header.IPv4MinimumSize { + return netip.Addr{}, 0 + } + hdr := header.IPv4(pkt) + src := hdr.SourceAddress() + ip, ok := netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.Addr{}, 0 + } + return ip, int(hdr.HeaderLength()) +} + +func srcIPv6(pkt []byte) (netip.Addr, int) { + if len(pkt) < header.IPv6MinimumSize { + return netip.Addr{}, 0 + } + hdr := header.IPv6(pkt) + src := hdr.SourceAddress() + ip, ok := 
netip.AddrFromSlice(src.AsSlice()) + if !ok { + return netip.Addr{}, 0 + } + return ip, header.IPv6MinimumSize +} From 63cf286891734618d58d0adc966f68cd56b85a3a Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 1 Apr 2026 12:25:37 +0200 Subject: [PATCH 4/8] Fix lint, firewall injection, and review feedback - Rename DNSFirewall to Firewall (revive: stuttering name in dns package) - Check Stop() error in service_listener_test (errcheck) - Fix firewall nil at DNS construction: add SetFirewall to Server interface, inject from engine after createFirewall completes - Close managers in hook tests and benchmarks to prevent goroutine leaks - Fix tcpResponseWriter.Write returning wire byte count instead of data - Add nil checks for UDP layer and GetDevice in memory service hook --- client/firewall/iptables/manager_linux.go | 16 ++ client/firewall/iptables/router_linux.go | 78 ++++++++ client/firewall/manager/firewall.go | 6 + client/firewall/nftables/manager_linux.go | 16 ++ client/firewall/nftables/router_linux.go | 125 ++++++++++++ client/firewall/uspfilter/filter_test.go | 2 + client/firewall/uspfilter/hooks_bench_test.go | 182 ------------------ client/firewall/uspfilter/nat.go | 17 ++ client/internal/dns/mock_server.go | 5 + client/internal/dns/server.go | 13 +- client/internal/dns/service.go | 8 +- client/internal/dns/service_listener.go | 12 +- client/internal/dns/service_listener_test.go | 25 ++- client/internal/dns/service_memory.go | 15 +- client/internal/dns/tcpstack.go | 5 +- client/internal/engine.go | 6 +- 16 files changed, 319 insertions(+), 212 deletions(-) delete mode 100644 client/firewall/uspfilter/hooks_bench_test.go diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 04c33837506..2fc6f8ec8dc 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -286,6 +286,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return 
m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + const ( chainNameRaw = "NETBIRD-RAW" chainOUTPUT = "OUTPUT" diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 1fe4c149f9d..53f6a5aa20f 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -36,6 +36,7 @@ const ( chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainNATOutput = "NETBIRD-NAT-OUTPUT" chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" @@ -43,6 +44,7 @@ const ( jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpNatOutput = "jump-nat-output" jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" @@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error { } log.Debug("flushing routing related tables") + + // Remove jump rules from built-in chains before deleting custom chains, + // otherwise the chain deletion fails with "device or resource busy". 
+ jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { + log.Debugf("clean OUTPUT jump rule: %v", err) + } + for _, chainInfo := range []struct { chain string table string @@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainNATOutput, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) @@ -970,6 +981,73 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto return nil } +// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use. +func (r *router) ensureNATOutputChain() error { + if _, exists := r.rules[jumpNatOutput]; exists { + return nil + } + + if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil { + return fmt.Errorf("create chain %s: %w", chainNATOutput, err) + } + + jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil { + if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil { + log.Debugf("failed to rollback chain %s: %v", chainNATOutput, delErr) + } + return fmt.Errorf("add OUTPUT jump rule: %w", err) + } + r.rules[jumpNatOutput] = jumpRule + + r.updateState() + return nil +} + +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
+func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + if err := r.ensureNATOutputChain(); err != nil { + return err + } + + dnatRule := []string{ + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil { + return fmt.Errorf("add output DNAT rule: %w", err) + } + r.rules[ruleID] = dnatRule + + r.updateState() + return nil +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil { + return fmt.Errorf("delete output DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3511a54630d..ee1eea50647 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -169,6 +169,12 @@ type Manager interface { // RemoveInboundDNAT removes inbound DNAT rule RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
+ AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. + RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. // This prevents conntrack from interfering with WireGuard proxy communication. SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index f57b28abc19..beb5b70a79b 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -346,6 +346,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. 
+func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + const ( chainNameRawOutput = "netbird-raw-out" chainNameRawPrerouting = "netbird-raw-pre" diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index fde654c20ca..904daf7cb68 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -36,6 +36,7 @@ const ( chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" + chainNameNATOutput = "netbird-nat-output" chainNameForward = "FORWARD" chainNameMangleForward = "netbird-mangle-forward" @@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto return nil } +// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use. +func (r *router) ensureNATOutputChain() error { + if _, exists := r.chains[chainNameNATOutput]; exists { + return nil + } + + r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameNATOutput, + Table: r.workTable, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityNATDest, + Type: nftables.ChainTypeNAT, + }) + + if err := r.conn.Flush(); err != nil { + delete(r.chains, chainNameNATOutput) + return fmt.Errorf("create NAT output chain: %w", err) + } + return nil +} + +// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
+func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + if err := r.ensureNATOutputChain(); err != nil { + return err + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameNATOutput], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add output DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. 
+func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + rule, exists := r.rules[ruleID] + if !exists { + return nil + } + + if rule.Handle == 0 { + log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID) + delete(r.rules, ruleID) + return nil + } + + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete output DNAT rule: %w", err) + } + delete(r.rules, ruleID) + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index cb1fe6f26f4..5f0f9f8602e 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -192,6 +192,7 @@ func TestSetUDPPacketHook(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) var called bool manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool { @@ -215,6 +216,7 @@ func TestSetTCPPacketHook(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) var called bool manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool { diff --git a/client/firewall/uspfilter/hooks_bench_test.go b/client/firewall/uspfilter/hooks_bench_test.go 
deleted file mode 100644 index be6a8408b30..00000000000 --- a/client/firewall/uspfilter/hooks_bench_test.go +++ /dev/null @@ -1,182 +0,0 @@ -package uspfilter - -import ( - "net" - "net/netip" - "testing" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/stretchr/testify/require" - - nbiface "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/device" -) - -func buildUDPPacket(b *testing.B, srcIP, dstIP string, srcPort, dstPort uint16) []byte { - b.Helper() - - ipLayer := &layers.IPv4{ - Version: 4, - TTL: 64, - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - Protocol: layers.IPProtocolUDP, - } - udpLayer := &layers.UDP{ - SrcPort: layers.UDPPort(srcPort), - DstPort: layers.UDPPort(dstPort), - } - if err := udpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { - b.Fatal(err) - } - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - if err := gopacket.SerializeLayers(buf, opts, ipLayer, udpLayer, gopacket.Payload([]byte("test"))); err != nil { - b.Fatal(err) - } - return buf.Bytes() -} - -func buildTCPPacket(b *testing.B, srcIP, dstIP string, srcPort, dstPort uint16) []byte { - b.Helper() - - ipLayer := &layers.IPv4{ - Version: 4, - TTL: 64, - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - Protocol: layers.IPProtocolTCP, - } - tcpLayer := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - SYN: true, - } - if err := tcpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { - b.Fatal(err) - } - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - if err := gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte("test"))); err != nil { - b.Fatal(err) - } - return buf.Bytes() -} - -func newBenchManager(b *testing.B) *Manager { - b.Helper() - m, err := Create(&IFaceMock{ - 
SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger, nbiface.DefaultMTU) - require.NoError(b, err) - return m -} - -// BenchmarkHooksDrop_UDPMatch measures the cost of the UDP hook check when the -// packet matches the registered hook (the DNS interception fast path). -func BenchmarkHooksDrop_UDPMatch(b *testing.B) { - m := newBenchManager(b) - m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) - - pkt := buildUDPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) - - b.ResetTimer() - b.ReportAllocs() - for b.Loop() { - m.udpHooksDrop(53, netip.MustParseAddr("100.10.255.254"), pkt) - } -} - -// BenchmarkHooksDrop_UDPMiss measures the cost when no UDP hook matches -// (common case for non-DNS traffic). -func BenchmarkHooksDrop_UDPMiss(b *testing.B) { - m := newBenchManager(b) - m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) - - pkt := buildUDPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 8080) - - b.ResetTimer() - b.ReportAllocs() - for b.Loop() { - m.udpHooksDrop(8080, netip.MustParseAddr("100.10.0.2"), pkt) - } -} - -// BenchmarkHooksDrop_TCPMatch measures the TCP hook check when matching (DNS TCP). -func BenchmarkHooksDrop_TCPMatch(b *testing.B) { - m := newBenchManager(b) - m.SetTCPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) - - pkt := buildTCPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) - - b.ResetTimer() - b.ReportAllocs() - for b.Loop() { - m.tcpHooksDrop(53, netip.MustParseAddr("100.10.255.254"), pkt) - } -} - -// BenchmarkHooksDrop_TCPMiss measures TCP hook check for non-matching traffic. 
-func BenchmarkHooksDrop_TCPMiss(b *testing.B) { - m := newBenchManager(b) - m.SetTCPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) - - pkt := buildTCPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 443) - - b.ResetTimer() - b.ReportAllocs() - for b.Loop() { - m.tcpHooksDrop(443, netip.MustParseAddr("100.10.0.2"), pkt) - } -} - -// BenchmarkHooksDrop_NoHooks measures the cost when no hooks are registered -// (the baseline for all non-DNS traffic). -func BenchmarkHooksDrop_NoHooks(b *testing.B) { - m := newBenchManager(b) - - pkt := buildUDPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 8080) - - b.ResetTimer() - b.ReportAllocs() - for b.Loop() { - m.udpHooksDrop(8080, netip.MustParseAddr("100.10.0.2"), pkt) - m.tcpHooksDrop(8080, netip.MustParseAddr("100.10.0.2"), pkt) - } -} - -// BenchmarkFilterOutbound_WithHooks benchmarks the full FilterOutbound path -// with both UDP and TCP hooks registered (the real-world DNS scenario). -func BenchmarkFilterOutbound_WithHooks(b *testing.B) { - m := newBenchManager(b) - m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) - m.SetTCPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool { return true }) - - udpDNS := buildUDPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) - tcpDNS := buildTCPPacket(b, "100.10.0.1", "100.10.255.254", 12345, 53) - tcpHTTPS := buildTCPPacket(b, "100.10.0.1", "100.10.0.2", 12345, 443) - - b.Run("udp_dns_match", func(b *testing.B) { - b.ReportAllocs() - for b.Loop() { - m.FilterOutbound(udpDNS, len(udpDNS)) - } - }) - - b.Run("tcp_dns_match", func(b *testing.B) { - b.ReportAllocs() - for b.Loop() { - m.FilterOutbound(tcpDNS, len(tcpDNS)) - } - }) - - b.Run("tcp_https_miss", func(b *testing.B) { - b.ReportAllocs() - for b.Loop() { - m.FilterOutbound(tcpHTTPS, len(tcpHTTPS)) - } - }) -} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 597f892cf7a..8ed32eb5e2e 
100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +// TODO: also delegate to nativeFirewall when available for kernel WG mode func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType switch protocol { @@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) } +// AddOutputDNAT delegates to the native firewall if available. +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if m.nativeFirewall == nil { + return fmt.Errorf("output DNAT not supported without native firewall") + } + return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveOutputDNAT delegates to the native firewall if available. +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if m.nativeFirewall == nil { + return nil + } + return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) +} + // translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. 
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { if !m.portDNATEnabled.Load() { diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 1df57d1db2d..548b1f54f9f 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { // Mock implementation - no-op } +// SetFirewall mock implementation of SetFirewall from Server interface +func (m *MockServer) SetFirewall(Firewall) { + // Mock implementation - no-op +} + // BeginBatch mock implementation of BeginBatch from Server interface func (m *MockServer) BeginBatch() { // Mock implementation - no-op diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index e14824231f4..1d9781855bc 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -58,6 +58,7 @@ type Server interface { UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error SetRouteChecker(func(netip.Addr) bool) + SetFirewall(Firewall) } type nsGroupsByDomain struct { @@ -130,7 +131,6 @@ type registeredHandlerMap map[types.HandlerID]handlerWrapper // DefaultServerConfig holds configuration parameters for NewDefaultServer type DefaultServerConfig struct { WgInterface WGIface - Firewall DNSFirewall CustomAddress string StatusRecorder *peer.Status StateManager *statemanager.Manager @@ -152,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default if config.WgInterface.IsUserspaceBind() { dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(config.WgInterface, addrPort, config.Firewall) + dnsService = newServiceViaListener(config.WgInterface, addrPort, nil) } server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) @@ -375,6 
+375,15 @@ func (s *DefaultServer) DnsIP() netip.Addr { return s.service.RuntimeIP() } +// SetFirewall sets the firewall used for DNS port DNAT rules. +// This must be called before Initialize when using the listener-based service, +// because the firewall is typically not available at construction time. +func (s *DefaultServer) SetFirewall(fw Firewall) { + if svc, ok := s.service.(*serviceViaListener); ok { + svc.firewall = fw + } +} + // Stop stops the server func (s *DefaultServer) Stop() { s.probeMu.Lock() diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index bd07110584e..1c6ce7849a4 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -12,12 +12,12 @@ const ( DefaultPort = 53 ) -// DNSFirewall provides DNAT capabilities for DNS port redirection. +// Firewall provides DNAT capabilities for DNS port redirection. // This is used when the DNS server cannot bind port 53 directly // and needs firewall rules to redirect traffic. -type DNSFirewall interface { - AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error - RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error +type Firewall interface { + AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error + RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error } type service interface { diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 315d4f09219..ac8fb87307d 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -41,18 +41,18 @@ type serviceViaListener struct { listenerIsRunning bool listenerFlagLock sync.Mutex ebpfService ebpfMgr.Manager - firewall DNSFirewall + firewall Firewall tcpDNATConfigured bool } -func newServiceViaListener(wgIface WGIface, customAddr 
*netip.AddrPort, firewall DNSFirewall) *serviceViaListener { +func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener { mux := dns.NewServeMux() s := &serviceViaListener{ wgInterface: wgIface, dnsMux: mux, customAddr: customAddr, - firewall: firewall, + firewall: fw, server: &dns.Server{ Net: "udp", Handler: mux, @@ -105,8 +105,8 @@ func (s *serviceViaListener) Listen() error { // When eBPF redirects UDP port 53 to our listen port, TCP still needs // a DNAT rule because eBPF only handles UDP. if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort { - if err := s.firewall.AddInboundDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { - log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err) } else { s.tcpDNATConfigured = true log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort) @@ -138,7 +138,7 @@ func (s *serviceViaListener) Stop() error { } if s.tcpDNATConfigured && s.firewall != nil { - if err := s.firewall.RemoveInboundDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { + if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) } s.tcpDNATConfigured = false diff --git a/client/internal/dns/service_listener_test.go b/client/internal/dns/service_listener_test.go index 51f52020f19..90ef71d1902 100644 --- a/client/internal/dns/service_listener_test.go +++ b/client/internal/dns/service_listener_test.go @@ -29,45 +29,42 @@ func TestServiceViaListener_TCPAndUDP(t *testing.T) { svc := newServiceViaListener(nil, nil, nil) svc.dnsMux.Handle(".", handler) - // Find a 
free port by binding and releasing + // Bind both transports up front to avoid TOCTOU races. udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0)) - udpLn, err := net.ListenUDP("udp", udpAddr) + udpConn, err := net.ListenUDP("udp", udpAddr) if err != nil { t.Skip("cannot bind to 127.0.0.153, skipping") } - port := uint16(udpLn.LocalAddr().(*net.UDPAddr).Port) - require.NoError(t, udpLn.Close()) + port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port) - // Check TCP is also available on this port tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port)) tcpLn, err := net.ListenTCP("tcp", tcpAddr) if err != nil { + udpConn.Close() t.Skip("cannot bind TCP on same port, skipping") } - require.NoError(t, tcpLn.Close()) addr := fmt.Sprintf("%s:%d", customIP, port) - svc.server.Addr = addr - svc.tcpServer.Addr = addr + svc.server.PacketConn = udpConn + svc.tcpServer.Listener = tcpLn svc.listenIP = customIP svc.listenPort = port go func() { - if err := svc.server.ListenAndServe(); err != nil { + if err := svc.server.ActivateAndServe(); err != nil { t.Logf("udp server: %v", err) } }() go func() { - if err := svc.tcpServer.ListenAndServe(); err != nil { + if err := svc.tcpServer.ActivateAndServe(); err != nil { t.Logf("tcp server: %v", err) } }() svc.listenerIsRunning = true - defer svc.Stop() - - // Wait for servers to start - time.Sleep(100 * time.Millisecond) + defer func() { + require.NoError(t, svc.Stop()) + }() q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index d6cc494b888..0f3414c56ac 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -121,7 +121,13 @@ func (s *ServiceViaMemory) filterDNSTraffic() error { packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) udpLayer := packet.Layer(layers.LayerTypeUDP) - udp := udpLayer.(*layers.UDP) + if udpLayer == nil { + return 
true + } + udp, ok := udpLayer.(*layers.UDP) + if !ok { + return true + } msg := new(dns.Msg) if err := msg.Unpack(udp.Payload); err != nil { @@ -129,11 +135,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() error { return true } + dev := s.wgInterface.GetDevice() + if dev == nil { + return true + } + writer := &truncationAwareWriter{ responseWriter: responseWriter{ remote: remoteAddrFromPacket(packet), packet: packet, - device: s.wgInterface.GetDevice().Device, + device: dev.Device, }, tcpDNS: s.tcpDNS, } diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go index cc61e92aad2..ff575926937 100644 --- a/client/internal/dns/tcpstack.go +++ b/client/internal/dns/tcpstack.go @@ -375,7 +375,10 @@ func (w *tcpResponseWriter) Write(data []byte) (int, error) { buf[0] = byte(len(data) >> 8) buf[1] = byte(len(data)) copy(buf[2:], data) - return w.conn.Write(buf) + if _, err := w.conn.Write(buf); err != nil { + return 0, err + } + return len(data), nil } func (w *tcpResponseWriter) Close() error { diff --git a/client/internal/engine.go b/client/internal/engine.go index cceb80ff840..bf2186997f5 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -521,6 +521,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return err } + // Inject firewall into DNS server now that it's available. + // The DNS server is created before the firewall because the route manager + // depends on the DNS server, and the firewall depends on the wg interface. 
+ e.dnsServer.SetFirewall(e.firewall) + e.udpMux, err = e.wgInterface.Up() if err != nil { log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) @@ -1807,7 +1812,6 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ WgInterface: e.wgInterface, - Firewall: e.firewall, CustomAddress: e.config.CustomDNSAddress, StatusRecorder: e.statusRecorder, StateManager: e.stateManager, From 0d8074d13519e97ac909e6da8f17fa5e4a2f7c7c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 1 Apr 2026 15:02:48 +0200 Subject: [PATCH 5/8] Address remaining review feedback - Make ensureNATOutputChain retry-safe by checking ChainExists before NewChain, handling orphaned chains from partial failures - Fix TestExchangeWithFallback_UDPFallbackToTCP to actually test the truncation-to-TCP retry path --- client/firewall/iptables/router_linux.go | 16 +++++++++++---- client/internal/dns/upstream_test.go | 26 ++++++++++++++++-------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 53f6a5aa20f..a7c4f67dd5c 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -987,14 +987,22 @@ func (r *router) ensureNATOutputChain() error { return nil } - if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil { - return fmt.Errorf("create chain %s: %w", chainNATOutput, err) + chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput) + if err != nil { + return fmt.Errorf("check chain %s: %w", chainNATOutput, err) + } + if !chainExists { + if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil { + return fmt.Errorf("create chain %s: %w", chainNATOutput, err) + } } jumpRule := []string{"-j", chainNATOutput} if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err 
!= nil { - if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil { - log.Debugf("failed to rollback chain %s: %v", chainNATOutput, delErr) + if !chainExists { + if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil { + log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr) + } } return fmt.Errorf("add OUTPUT jump rule: %w", err) } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index cb9fdf51786..537f8902f56 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -540,9 +540,18 @@ func TestExchangeWithFallback_TCPContext(t *testing.T) { } func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) { - // Start a server on both UDP and TCP. - // The handler returns a small response that works on both. - handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + // UDP handler returns a truncated response to trigger TCP retry. + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = true + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + // TCP handler returns the full answer. 
+ tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Answer = append(m.Answer, &dns.A{ @@ -561,7 +570,7 @@ func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) { udpServer := &dns.Server{ PacketConn: udpPC, Net: "udp", - Handler: handler, + Handler: udpHandler, } tcpLn, err := net.Listen("tcp", addr) @@ -570,7 +579,7 @@ func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) { tcpServer := &dns.Server{ Listener: tcpLn, Net: "tcp", - Handler: handler, + Handler: tcpHandler, } go func() { @@ -588,17 +597,16 @@ func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) { _ = tcpServer.Shutdown() }() - // Normal UDP exchange without TCP context should succeed via UDP ctx := context.Background() client := &dns.Client{Timeout: 2 * time.Second} r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) rm, _, err := ExchangeWithFallback(ctx, client, r, addr) - require.NoError(t, err) + require.NoError(t, err, "should fall back to TCP after truncated UDP response") require.NotNil(t, rm) - require.NotEmpty(t, rm.Answer) + require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer") assert.Contains(t, rm.Answer[0].String(), "10.0.0.3") - assert.False(t, rm.Truncated, "small response should not be truncated") + assert.False(t, rm.Truncated, "TCP response should not be truncated") } func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) { From 45a7839ce3a89795a398598c57250cc74993559e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 2 Apr 2026 14:34:22 +0200 Subject: [PATCH 6/8] Fix DNS listener race conditions and TCP leak from review feedback --- client/internal/dns/server.go | 2 ++ client/internal/dns/service_listener.go | 23 ++++++++++++++--------- client/internal/dns/tcpstack.go | 2 +- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 1d9781855bc..35f6708de07 100644 --- 
a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -380,7 +380,9 @@ func (s *DefaultServer) DnsIP() netip.Addr { // because the firewall is typically not available at construction time. func (s *DefaultServer) SetFirewall(fw Firewall) { if svc, ok := s.service.(*serviceViaListener); ok { + svc.listenerFlagLock.Lock() svc.firewall = fw + svc.listenerFlagLock.Unlock() } } diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index ac8fb87307d..4e09f1b7f22 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -87,13 +87,23 @@ func (s *serviceViaListener) Listen() error { s.tcpServer.Addr = addr log.Debugf("starting dns on %s (UDP + TCP)", addr) - go func() { - s.setListenerStatus(true) - defer s.setListenerStatus(false) + s.listenerIsRunning = true + go func() { if err := s.server.ListenAndServe(); err != nil { log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err) } + + s.listenerFlagLock.Lock() + unexpected := s.listenerIsRunning + s.listenerIsRunning = false + s.listenerFlagLock.Unlock() + + if unexpected { + if err := s.tcpServer.Shutdown(); err != nil { + log.Debugf("failed to shutdown DNS TCP server: %v", err) + } + } }() go func() { @@ -123,6 +133,7 @@ func (s *serviceViaListener) Stop() error { if !s.listenerIsRunning { return nil } + s.listenerIsRunning = false ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -177,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { return s.listenIP } -func (s *serviceViaListener) setListenerStatus(running bool) { - s.listenerFlagLock.Lock() - defer s.listenerFlagLock.Unlock() - - s.listenerIsRunning = running -} // evalListenAddress figure out the listen address for the DNS server // first check the 53 port availability on WG interface or lo, if not success diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go index 
ff575926937..17c82fb20ef 100644 --- a/client/internal/dns/tcpstack.go +++ b/client/internal/dns/tcpstack.go @@ -269,7 +269,7 @@ func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) { } for { - if err := conn.SetDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil { + if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil { log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err) break } From 36024a14c42fafe7ed7b1e83bdcfb1f7656e9330 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 2 Apr 2026 16:58:15 +0200 Subject: [PATCH 7/8] Address remaining CodeRabbit review feedback --- client/firewall/manager/firewall.go | 2 ++ client/firewall/uspfilter/filter.go | 2 ++ client/internal/dns/server.go | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index ee1eea50647..d65d717b376 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -170,9 +170,11 @@ type Manager interface { RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. + // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. + // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. 
diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 6415af17fd0..cb9e1bb0af7 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -605,6 +605,8 @@ func (m *Manager) resetState() { maps.Clear(m.incomingRules) maps.Clear(m.routeRulesMap) m.routeRules = m.routeRules[:0] + m.udpHookOut.Store(nil) + m.tcpHookOut.Store(nil) if m.udpTracker != nil { m.udpTracker.Close() diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 35f6708de07..d4fda5db302 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -407,10 +407,10 @@ func (s *DefaultServer) Stop() { maps.Clear(s.extraDomains) } -func (s *DefaultServer) disableDNS() error { +func (s *DefaultServer) disableDNS() (retErr error) { defer func() { if err := s.service.Stop(); err != nil { - log.Errorf("failed to stop DNS service: %v", err) + retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err)) } }() From c5932ef346e3bd541629ced09da165b33af1b44a Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 2 Apr 2026 21:03:57 +0200 Subject: [PATCH 8/8] Cap EDNS0 to tunnel MTU, simplify TCP DNS stack startup --- client/internal/dns/handler_chain.go | 3 + client/internal/dns/response_writer.go | 16 ---- client/internal/dns/service_memory.go | 11 +-- client/internal/dns/tcpstack.go | 30 +------ client/internal/dns/upstream.go | 67 +++++++++------ client/internal/dns/upstream_test.go | 110 +++++++++++++++++++++++++ 6 files changed, 160 insertions(+), 77 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index f183052cd6a..6fbdedc5953 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { return nil } w.response = m + if m.MsgHdr.Truncated { + w.SetMeta("truncated", "true") + } return w.ResponseWriter.WriteMsg(m) } diff --git 
a/client/internal/dns/response_writer.go b/client/internal/dns/response_writer.go index 1268b85f5a6..287cf28b032 100644 --- a/client/internal/dns/response_writer.go +++ b/client/internal/dns/response_writer.go @@ -105,22 +105,6 @@ func (r *responseWriter) TsigTimersOnly(bool) { func (r *responseWriter) Hijack() { } -// truncationAwareWriter wraps a UDP responseWriter and starts the TCP DNS -// stack when a truncated response is about to be sent. This ensures the -// TCP stack is ready when the client retries over TCP. -type truncationAwareWriter struct { - responseWriter - tcpDNS *tcpDNSServer -} - -// WriteMsg checks if the response is truncated and starts the TCP stack if needed. -func (w *truncationAwareWriter) WriteMsg(msg *dns.Msg) error { - if msg.MsgHdr.Truncated && w.tcpDNS != nil { - w.tcpDNS.EnsureRunning() - } - return w.responseWriter.WriteMsg(msg) -} - // remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging. func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr { var srcIP net.IP diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 0f3414c56ac..e8c0360766c 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -140,13 +140,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() error { return true } - writer := &truncationAwareWriter{ - responseWriter: responseWriter{ - remote: remoteAddrFromPacket(packet), - packet: packet, - device: dev.Device, - }, - tcpDNS: s.tcpDNS, + writer := &responseWriter{ + remote: remoteAddrFromPacket(packet), + packet: packet, + device: dev.Device, } go s.dnsMux.ServeDNS(writer, msg) return true diff --git a/client/internal/dns/tcpstack.go b/client/internal/dns/tcpstack.go index 17c82fb20ef..88e72e767ca 100644 --- a/client/internal/dns/tcpstack.go +++ b/client/internal/dns/tcpstack.go @@ -59,34 +59,10 @@ func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port u } } -// 
EnsureRunning starts the TCP stack if not already running and resets the idle timer. -func (t *tcpDNSServer) EnsureRunning() { - t.mu.Lock() - defer t.mu.Unlock() - - if t.closed { - return - } - - if t.running { - t.resetTimerLocked() - return - } - - if err := t.startLocked(); err != nil { - log.Errorf("failed to start TCP DNS stack: %v", err) - return - } - - t.running = true - t.resetTimerLocked() - log.Debugf("TCP DNS stack started on %s:%d", t.ip, t.port) -} - // InjectPacket ensures the stack is running and delivers a raw IP packet into // the gvisor stack for TCP processing. Combining both operations under a single // lock prevents a race where the idle timer could stop the stack between -// EnsureRunning and delivery. +// start and delivery. func (t *tcpDNSServer) InjectPacket(payload []byte) { t.mu.Lock() defer t.mu.Unlock() @@ -118,7 +94,7 @@ func (t *tcpDNSServer) InjectPacket(payload []byte) { } // Stop tears down the gvisor stack and releases resources permanently. -// After Stop, EnsureRunning becomes a no-op. +// After Stop, InjectPacket becomes a no-op. func (t *tcpDNSServer) Stop() { t.mu.Lock() defer t.mu.Unlock() @@ -227,7 +203,7 @@ func (t *tcpDNSServer) resetTimerLocked() { defer t.mu.Unlock() // Only stop if this timer is still the active one. - // A racing EnsureRunning may have replaced it. + // A racing InjectPacket may have replaced it. if t.timerID != id { return } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c61a569a5cc..746b73ca754 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -41,6 +41,10 @@ const ( reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second + + // Maximum IP header size (60 bytes, IPv4 with options) + UDP header size (8 bytes), + // used to derive the maximum DNS UDP payload from the tunnel MTU. + ipUDPHeaderSize = 60 + 8 ) const testRecord = "com." 
@@ -489,6 +493,14 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC return err } +// clientUDPMaxSize returns the maximum UDP response size the client accepts. +func clientUDPMaxSize(r *dns.Msg) int { + if opt := r.IsEdns0(); opt != nil { + return int(opt.UDPSize()) + } + return dns.MinMsgSize +} + // ExchangeWithFallback exchanges a DNS message with the upstream server. // It first tries to use UDP, and if it is truncated, it falls back to TCP. // If the inbound request came over TCP (via context), it skips the UDP attempt. @@ -506,9 +518,17 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } - // MTU - ip + udp headers - // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling. - client.UDPSize = uint16(currentMTU - (60 + 8)) + clientMaxSize := clientUDPMaxSize(r) + + // Cap EDNS0 to our tunnel MTU so the upstream doesn't send a + // response larger than our read buffer. + // Note: the query could be sent out on an interface that is not ours, + // but higher MTU settings could break truncation handling. + maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize) + client.UDPSize = maxUDPPayload + if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload { + opt.SetUDPSize(maxUDPPayload) + } var ( rm *dns.Msg @@ -531,8 +551,9 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } - log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.", - r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + // TODO: if the upstream's truncated UDP response already contains more + // data than the client's buffer, we could truncate locally and skip + // the TCP retry. 
tcpClient := *client tcpClient.Net = protoTCP @@ -549,14 +570,8 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u setUpstreamProtocol(ctx, protoTCP) - // Request came in over UDP but response was fetched via TCP. - // Truncate to fit the client's UDP buffer. - maxSize := dns.MinMsgSize - if opt := r.IsEdns0(); opt != nil { - maxSize = int(opt.UDPSize()) - } - if rm.Len() > maxSize { - rm.Truncate(maxSize) + if rm.Len() > clientMaxSize { + rm.Truncate(clientMaxSize) } return rm, t, nil @@ -575,31 +590,29 @@ func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, return rm, nil } + clientMaxSize := clientUDPMaxSize(r) + + // Cap EDNS0 to our tunnel MTU so the upstream doesn't send a + // response larger than what we can read over UDP. + maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize) + if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload { + opt.SetUDPSize(maxUDPPayload) + } + reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP) if err != nil { return nil, err } - // If response is truncated, retry with TCP if reply != nil && reply.MsgHdr.Truncated { - log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP", - r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP) if err != nil { return nil, err } setUpstreamProtocol(ctx, protoTCP) - - // Request came in over UDP but response was fetched via TCP. - // Truncate to fit the client's UDP buffer. 
- maxSize := dns.MinMsgSize - if opt := r.IsEdns0(); opt != nil { - maxSize = int(opt.UDPSize()) - } - if rm.Len() > maxSize { - rm.Truncate(maxSize) + if rm.Len() > clientMaxSize { + rm.Truncate(clientMaxSize) } return rm, nil @@ -627,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst } } - dnsConn := &dns.Conn{Conn: conn} + dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)} if err := dnsConn.WriteMsg(r); err != nil { return nil, fmt.Errorf("write %s message: %w", network, err) diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 537f8902f56..1797fdad81d 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -660,3 +660,113 @@ func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) { _, _, err = ExchangeWithFallback(ctx2, client2, r, upstream) assert.Error(t, err, "should fail when no UDP server and no TCP context") } + +func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { + // Verify that a client EDNS0 larger than our MTU-derived limit gets + // capped in the outgoing request so the upstream doesn't send a + // response larger than our read buffer. 
+ var receivedUDPSize uint16 + udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + if opt := r.IsEdns0(); opt != nil { + receivedUDPSize = opt.UDPSize() + } + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1"), + }) + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler} + go func() { _ = udpServer.ActivateAndServe() }() + t.Cleanup(func() { _ = udpServer.Shutdown() }) + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + r.SetEdns0(4096, false) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err) + require.NotNil(t, rm) + + expectedMax := uint16(currentMTU - ipUDPHeaderSize) + assert.Equal(t, expectedMax, receivedUDPSize, + "upstream should see capped EDNS0, not the client's 4096") +} + +func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) { + // When the client advertises a large EDNS0 (4096) and the upstream + // truncates, the TCP response should NOT be truncated since the full + // answer fits within the client's original buffer. 
+ udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = true + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + // Add enough records to exceed MTU but fit within 4096 + for i := range 20 { + m.Answer = append(m.Answer, &dns.TXT{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60}, + Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)}, + }) + } + if err := w.WriteMsg(m); err != nil { + t.Logf("write msg: %v", err) + } + }) + + udpPC, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + addr := udpPC.LocalAddr().String() + + udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler} + tcpLn, err := net.Listen("tcp", addr) + require.NoError(t, err) + tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler} + + go func() { _ = udpServer.ActivateAndServe() }() + go func() { _ = tcpServer.ActivateAndServe() }() + t.Cleanup(func() { + _ = udpServer.Shutdown() + _ = tcpServer.Shutdown() + }) + + ctx := context.Background() + client := &dns.Client{Timeout: 2 * time.Second} + + // Client with large buffer: should get all records without truncation + r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + r.SetEdns0(4096, false) + + rm, _, err := ExchangeWithFallback(ctx, client, r, addr) + require.NoError(t, err) + require.NotNil(t, rm) + assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records") + assert.False(t, rm.Truncated, "response should not be truncated for large buffer client") + + // Client with small buffer: should get truncated response + r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + r2.SetEdns0(512, false) + + rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, 
addr) + require.NoError(t, err) + require.NotNil(t, rm2) + assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records") + assert.True(t, rm2.Truncated, "response should be truncated for small buffer client") +}