Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 1 addition & 33 deletions client/firewall/uspfilter/allow_netbird.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
package uspfilter

import (
"context"
"net/netip"
"time"

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/internal/statemanager"
)

Expand All @@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()

m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
}

if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}

if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()

if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)
Expand Down
31 changes: 1 addition & 30 deletions client/firewall/uspfilter/allow_netbird_windows.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package uspfilter

import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"

log "github.com/sirupsen/logrus"

Expand All @@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()

m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
}

if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}

if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()

if !isWindowsFirewallReachable() {
return nil
Expand Down
60 changes: 55 additions & 5 deletions client/firewall/uspfilter/filter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package uspfilter

import (
"context"
"encoding/binary"
"errors"
"fmt"
Expand All @@ -12,18 +13,21 @@ import (
"strings"
"sync"
"sync/atomic"
"time"

"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"

firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
Expand Down Expand Up @@ -89,6 +93,7 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
Expand Down Expand Up @@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
Expand Down Expand Up @@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}

ruleID := uuid.New().String()
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)

if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}

rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
id: string(ruleKey),
mgmtId: id,
sources: sources,
dstSet: destination.Set,
Expand All @@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(

m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule

return &rule, nil
}
Expand All @@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}

ruleID := rule.ID()
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}

idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID
return r.id == string(ruleKey)
})
if idx < 0 {
return fmt.Errorf("route rule not found: %s", ruleID)
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
}

m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}

Expand Down Expand Up @@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }

// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]

if m.udpTracker != nil {
m.udpTracker.Close()
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
}

if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}

if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}

// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
Expand Down
Loading
Loading