diff --git a/p2p/host/resource-manager/conn_limiter.go b/p2p/host/resource-manager/conn_limiter.go index 717249e854..73c288f077 100644 --- a/p2p/host/resource-manager/conn_limiter.go +++ b/p2p/host/resource-manager/conn_limiter.go @@ -114,8 +114,8 @@ type connLimiter struct { // Subnet limits. connLimitPerSubnetV4 []ConnLimitPerSubnet connLimitPerSubnetV6 []ConnLimitPerSubnet - ip4connsPerLimit []map[string]int - ip6connsPerLimit []map[string]int + ip4connsPerLimit []map[netip.Prefix]int + ip6connsPerLimit []map[netip.Prefix]int } func newConnLimiter() *connLimiter { @@ -180,7 +180,7 @@ func (cl *connLimiter) addConn(ip netip.Addr) bool { } if len(connsPerLimit) == 0 && len(limits) > 0 { - connsPerLimit = make([]map[string]int, len(limits)) + connsPerLimit = make([]map[netip.Prefix]int, len(limits)) if isIP6 { cl.ip6connsPerLimit = connsPerLimit } else { @@ -193,13 +193,12 @@ func (cl *connLimiter) addConn(ip netip.Addr) bool { if err != nil { return false } - masked := prefix.String() - counts, ok := connsPerLimit[i][masked] + counts, ok := connsPerLimit[i][prefix] if !ok { if connsPerLimit[i] == nil { - connsPerLimit[i] = make(map[string]int) + connsPerLimit[i] = make(map[netip.Prefix]int) } - connsPerLimit[i][masked] = 0 + connsPerLimit[i][prefix] = 0 } if counts+1 > limit.ConnCount { return false @@ -209,8 +208,7 @@ func (cl *connLimiter) addConn(ip netip.Addr) bool { // All limit checks passed, now we update the counts for i, limit := range limits { prefix, _ := ip.Prefix(limit.PrefixLength) - masked := prefix.String() - connsPerLimit[i][masked]++ + connsPerLimit[i][prefix]++ } return true @@ -258,7 +256,7 @@ func (cl *connLimiter) rmConn(ip netip.Addr) { if len(connsPerLimit) == 0 && len(limits) > 0 { // Initialize just in case. We should have already initialized in // addConn, but if the callers calls rmConn first we don't want to panic - connsPerLimit = make([]map[string]int, len(limits)) + connsPerLimit = make([]map[netip.Prefix]int, len(limits)) if isIP6 { cl.ip6connsPerLimit = connsPerLimit } else { @@ -273,16 +271,15 @@ func (cl *connLimiter) rmConn(ip netip.Addr) { log.Errorf("unexpected error getting prefix: %v", err) continue } - masked := prefix.String() - counts, ok := connsPerLimit[i][masked] + counts, ok := connsPerLimit[i][prefix] if !ok || counts == 0 { // Unexpected, but don't panic - log.Errorf("unexpected conn count for %s ok=%v count=%v", masked, ok, counts) + log.Errorf("unexpected conn count for %s ok=%v count=%v", prefix, ok, counts) continue } - connsPerLimit[i][masked]-- - if connsPerLimit[i][masked] <= 0 { - delete(connsPerLimit[i], masked) + connsPerLimit[i][prefix]-- + if connsPerLimit[i][prefix] <= 0 { + delete(connsPerLimit[i], prefix) } } } diff --git a/p2p/host/resource-manager/conn_limiter_test.go b/p2p/host/resource-manager/conn_limiter_test.go index 89ce5d53c6..5c06b65add 100644 --- a/p2p/host/resource-manager/conn_limiter_test.go +++ b/p2p/host/resource-manager/conn_limiter_test.go @@ -24,6 +24,22 @@ func TestItLimits(t *testing.T) { require.NoError(t, err) require.True(t, cl.addConn(otherIP)) }) + + t.Run("IPv4 removal", func(t *testing.T) { + ip, err := netip.ParseAddr("1.2.3.4") + require.NoError(t, err) + cl := newConnLimiter() + cl.connLimitPerSubnetV4[0].ConnCount = 1 + require.True(t, cl.addConn(ip)) + + // should fail the second time + require.False(t, cl.addConn(ip)) + // remove the connection + cl.rmConn(ip) + // should succeed now + require.True(t, cl.addConn(ip)) + }) + t.Run("IPv6", func(t *testing.T) { ip, err := netip.ParseAddr("1:2:3:4::1") require.NoError(t, err)