Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions client/firewall/create_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}

// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
Comment thread
lixmal marked this conversation as resolved.

return fm, nil
}

Expand Down
37 changes: 37 additions & 0 deletions client/firewall/uspfilter/common/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package common

import (
"net/netip"
"sync/atomic"
)

// PacketHook stores a registered hook for a specific IP:port.
type PacketHook struct {
IP netip.Addr
Port uint16
Fn func([]byte) bool
}

// HookMatches checks if a packet's destination matches the hook and invokes it.
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.IP == dstIP && h.Port == dport {
return h.Fn(packetData)
}
return false
}

// SetHook atomically stores a hook, handling nil removal.
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
if hook == nil {
ptr.Store(nil)
return
}
ptr.Store(&PacketHook{
IP: ip,
Port: dPort,
Fn: hook,
})
}
45 changes: 6 additions & 39 deletions client/firewall/uspfilter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,8 @@ type Manager struct {
mssClampEnabled bool

// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}

// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
udpHookOut atomic.Pointer[common.PacketHook]
tcpHookOut atomic.Pointer[common.PacketHook]
}

// decoder for packages
Expand Down Expand Up @@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
}

func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}

func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}

func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
return false
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}

// filterInbound implements filtering logic for incoming packets.
Expand Down Expand Up @@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot

// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
common.SetHook(&m.udpHookOut, ip, dPort, hook)
}

// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
}

// SetLogLevel sets the log level for the firewall manager
Expand Down
12 changes: 6 additions & 6 deletions client/firewall/uspfilter/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {

h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(8000), h.Port)
assert.True(t, h.Fn(nil))
assert.True(t, called)

manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
Expand All @@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {

h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(53), h.Port)
assert.True(t, h.Fn(nil))
assert.True(t, called)

manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
Expand Down
90 changes: 90 additions & 0 deletions client/firewall/uspfilter/hooks_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package uspfilter

import (
"encoding/binary"
"net/netip"
"sync/atomic"

"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/iface/device"
)

const (
ipv4HeaderMinLen = 20
ipv4ProtoOffset = 9
ipv4FlagsOffset = 6
ipv4DstOffset = 16
ipProtoUDP = 17
ipProtoTCP = 6
ipv4FragOffMask = 0x1fff
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
dstPortOffset = 2
)

// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
// It is installed on the WireGuard interface when the userspace bind is active
// but a full firewall filter (Manager) is not needed because a native kernel
// firewall (nftables/iptables) handles packet filtering.
type HooksFilter struct {
udpHook atomic.Pointer[common.PacketHook]
tcpHook atomic.Pointer[common.PacketHook]
}

var _ device.PacketFilter = (*HooksFilter)(nil)

// FilterOutbound checks outbound packets for DNS hook matches.
// Only IPv4 packets matching the registered hook IP:port are intercepted.
// IPv6 and non-IP packets pass through unconditionally.
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
if len(packetData) < ipv4HeaderMinLen {
return false
}

// Only process IPv4 packets, let everything else pass through.
if packetData[0]>>4 != 4 {
return false
}

ihl := int(packetData[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
return false
}

// Skip non-first fragments: they don't carry L4 headers.
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
if flagsAndOffset&ipv4FragOffMask != 0 {
return false
}

dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
if !ok {
return false
}

proto := packetData[ipv4ProtoOffset]
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
Comment thread
coderabbitai[bot] marked this conversation as resolved.

switch proto {
case ipProtoUDP:
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
case ipProtoTCP:
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
default:
return false
}
}

// FilterInbound allows all inbound packets (native firewall handles filtering).
func (f *HooksFilter) FilterInbound([]byte, int) bool {
return false
}

// SetUDPPacketHook registers the UDP packet hook.
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.udpHook, ip, dPort, hook)
}

// SetTCPPacketHook registers the TCP packet hook.
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.tcpHook, ip, dPort, hook)
}
Loading