Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 13 additions & 22 deletions client/firewall/nftables/router_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
log.Warnf("failed to load filter table, skipping accept rules: %v", err)
}

return r, nil
Expand Down Expand Up @@ -175,7 +171,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("unable to list tables: %v", err)
return nil, fmt.Errorf("list tables: %w", err)
}

for _, table := range tables {
Expand All @@ -193,8 +189,6 @@ func (r *router) createContainers() error {
Table: r.workTable,
})

insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])

prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Expand Down Expand Up @@ -236,9 +230,12 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter,
})

// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])

r.addPostroutingRules()

if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}

if err := r.addMSSClampingRules(); err != nil {
Expand All @@ -250,11 +247,7 @@ func (r *router) createContainers() error {
}

if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}

if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
log.Errorf("failed to refresh rules: %s", err)
}

return nil
Expand Down Expand Up @@ -695,7 +688,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}

// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
func (r *router) addPostroutingRules() {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
Expand Down Expand Up @@ -761,8 +754,6 @@ func (r *router) addPostroutingRules() error {
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})

return nil
}

// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
Expand Down Expand Up @@ -839,7 +830,7 @@ func (r *router) addMSSClampingRules() error {
Exprs: exprsOut,
})

return nil
return r.conn.Flush()
}

// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
Expand Down Expand Up @@ -1068,7 +1059,7 @@ func (r *router) acceptFilterRulesNftables() error {
}
r.conn.InsertRule(inputRule)

return nil
return r.conn.Flush()
}

func (r *router) removeAcceptFilterRules() error {
Expand Down Expand Up @@ -1196,7 +1187,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf(" unable to list rules: %v", err)
return fmt.Errorf("list rules: %w", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
Expand Down
Loading