diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index b6e0cf5b277..fde654c20ca 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { } if nftRule.Handle == 0 { - return fmt.Errorf("route rule %s has no handle", ruleKey) + log.Warnf("route rule %s has no handle, removing stale entry", ruleKey) + if err := r.decrementSetCounter(nftRule); err != nil { + log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err) + } + delete(r.rules, ruleKey) + return nil } if err := r.deleteNftRule(nftRule, ruleKey); err != nil { @@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { } if err := r.conn.Flush(); err != nil { - // TODO: rollback ipset counter - return fmt.Errorf("insert rules for %s: %v", pair.Destination, err) + r.rollbackRules(pair) + return fmt.Errorf("insert rules for %s: %w", pair.Destination, err) } return nil } +// rollbackRules cleans up unflushed rules and their set counters after a flush failure. +func (r *router) rollbackRules(pair firewall.RouterPair) { + keys := []string{ + firewall.GenKey(firewall.ForwardingFormat, pair), + firewall.GenKey(firewall.PreroutingFormat, pair), + firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)), + } + for _, key := range keys { + rule, ok := r.rules[key] + if !ok { + continue + } + if err := r.decrementSetCounter(rule); err != nil { + log.Warnf("rollback set counter for %s: %v", key, err) + } + delete(r.rules, key) + } +} + // addNatRule inserts a nftables rule to the conn client flush queue func (r *router) addNatRule(pair firewall.RouterPair) error { sourceExp, err := r.applyNetwork(pair.Source, nil, true) @@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) - if rule, exists := r.rules[ruleKey]; exists { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + rule, exists := r.rules[ruleKey] + if !exists { + return nil + } + + if rule.Handle == 0 { + log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey) + if err := r.decrementSetCounter(rule); err != nil { + log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err) } + delete(r.rules, ruleKey) + return nil + } - log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err) + } - delete(r.rules, ruleKey) + log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) - if err := r.decrementSetCounter(rule); err != nil { - return fmt.Errorf("decrement set counter: %w", err) - } + delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) } return nil @@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf(refreshRulesMapError, err) } + var merr *multierror.Error + if pair.Masquerade { if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove prerouting rule: %w", err) + merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err)) } if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse prerouting rule: %w", err) + merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err)) } } if err := r.removeLegacyRouteRule(pair); err != nil { - return fmt.Errorf("remove legacy routing rule: %w", err) + merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err)) } + // Set counters are decremented in the sub-methods above before flush. If flush fails, + // counters will be off until the next successful removal or refresh cycle. if err := r.conn.Flush(); err != nil { - // TODO: rollback set counter - return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err) + merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err)) } - return nil + return nberrors.FormatErrorOrNil(merr) } func (r *router) removeNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) - if rule, exists := r.rules[ruleKey]; exists { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) + rule, exists := r.rules[ruleKey] + if !exists { + log.Debugf("prerouting rule %s not found", ruleKey) + return nil + } + + if rule.Handle == 0 { + log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey) + if err := r.decrementSetCounter(rule); err != nil { + log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err) } + delete(r.rules, ruleKey) + return nil + } - log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination) + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err) + } - delete(r.rules, ruleKey) + log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination) - if err := r.decrementSetCounter(rule); err != nil { - return fmt.Errorf("decrement set counter: %w", err) - } - } else { - log.Debugf("prerouting rule %s not found", ruleKey) + delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) } return nil } -// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid -// duplicates and to get missing attributes that we don't have when adding new rules +// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries +// (e.g. from failed flushes) and updates handles for all existing rules. func (r *router) refreshRulesMap() error { + var merr *multierror.Error + newRules := make(map[string]*nftables.Rule) for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf("list rules: %w", err) + merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err)) + // preserve existing entries for this chain since we can't verify their state + for k, v := range r.rules { + if v.Chain != nil && v.Chain.Name == chain.Name { + newRules[k] = v + } + } + continue } for _, rule := range rules { if len(rule.UserData) > 0 { - r.rules[string(rule.UserData)] = rule + newRules[string(rule.UserData)] = rule } } } - return nil + r.rules = newRules + return nberrors.FormatErrorOrNil(merr) } func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { @@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { } var merr *multierror.Error + var needsFlush bool + if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists { - if err := r.conn.DelRule(dnatRule); err != nil { + if dnatRule.Handle == 0 { + log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix) + delete(r.rules, ruleKey+dnatSuffix) + } else if err := r.conn.DelRule(dnatRule); err != nil { merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err)) + } else { + needsFlush = true } } if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists { - if err := r.conn.DelRule(masqRule); err != nil { + if masqRule.Handle == 0 { + log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix) + delete(r.rules, ruleKey+snatSuffix) + } else if err := r.conn.DelRule(masqRule); err != nil { merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err)) + } else { + needsFlush = true } } - if err := r.conn.Flush(); err != nil { - merr = multierror.Append(merr, fmt.Errorf(flushError, err)) + if needsFlush { + if err := r.conn.Flush(); err != nil { + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) + } } if merr == nil { @@ -1757,15 +1831,24 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) - if rule, exists := r.rules[ruleID]; exists { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) - } - if err := r.conn.Flush(); err != nil { - return fmt.Errorf("flush delete inbound DNAT rule: %w", err) - } + rule, exists := r.rules[ruleID] + if !exists { + return nil + } + + if rule.Handle == 0 { + log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID) delete(r.rules, ruleID) + return nil + } + + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete inbound DNAT rule: %w", err) } + delete(r.rules, ruleID) return nil } diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 3531b014be7..f0e34d211f5 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -18,6 +18,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/acl/id" ) const ( @@ -719,3 +720,137 @@ func deleteWorkTable() { } } } + +func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err) + defer deleteWorkTable() + + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) + require.NoError(t, err) + require.NoError(t, r.init(workTable)) + defer func() { require.NoError(t, r.Reset()) }() + + // Add a real rule to the kernel + ruleKey, err := r.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")}, + firewall.ProtocolTCP, + nil, + &firewall.Port{Values: []uint16{80}}, + firewall.ActionAccept, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, r.DeleteRouteRule(ruleKey)) + }) + + // Inject a stale entry with Handle=0 (simulates store-before-flush failure) + staleKey := "stale-rule-that-does-not-exist" + r.rules[staleKey] = &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingFw], + Handle: 0, + UserData: []byte(staleKey), + } + + require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh") + + err = r.refreshRulesMap() + require.NoError(t, err) + + assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh") + + realRule, ok := r.rules[ruleKey.ID()] + assert.True(t, ok, "real rule should still exist after refresh") + assert.NotZero(t, realRule.Handle, "real rule should have a valid handle") +} + +func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err) + defer deleteWorkTable() + + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) + require.NoError(t, err) + require.NoError(t, r.init(workTable)) + defer func() { require.NoError(t, r.Reset()) }() + + // Inject a stale entry with Handle=0 + staleKey := "stale-route-rule" + r.rules[staleKey] = &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingFw], + Handle: 0, + UserData: []byte(staleKey), + } + + // DeleteRouteRule should not return an error for stale handles + err = r.DeleteRouteRule(id.RuleID(staleKey)) + assert.NoError(t, err, "deleting a stale rule should not error") + assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up") +} + +func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + manager, err := Create(ifaceMock, iface.DefaultMTU) + require.NoError(t, err) + require.NoError(t, manager.Init(nil)) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + pair := firewall.RouterPair{ + ID: "staletest", + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, + Masquerade: true, + } + + rtr := manager.router + + // First add succeeds + err = rtr.AddNatRule(pair) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, rtr.RemoveNatRule(pair)) + }) + + // Corrupt the handle to simulate stale state + natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair) + if rule, exists := rtr.rules[natRuleKey]; exists { + rule.Handle = 0 + } + inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)) + if rule, exists := rtr.rules[inverseKey]; exists { + rule.Handle = 0 + } + + // Adding the same rule again should succeed despite stale handles + err = rtr.AddNatRule(pair) + assert.NoError(t, err, "AddNatRule should succeed even with stale entries") + + // Verify rules exist in kernel + rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting]) + require.NoError(t, err) + + found := 0 + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + found++ + } + } + assert.Equal(t, 1, found, "NAT rule should exist in kernel") +}