diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index 05e0bdd9fd..dafac4dfe2 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -30,7 +30,7 @@ var ( errDialDataRefused = errors.New("dial data refused") ) -type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool +type dataRequestPolicyFunc = func(observedAddr, dialAddr ma.Multiaddr) bool type EventDialRequestCompleted struct { Error error @@ -212,7 +212,7 @@ func (as *server) serveDialRequest(s network.Stream) EventDialRequestCompleted { nonce := msg.GetDialRequest().Nonce - isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr) + isDialDataRequired := as.dialDataRequestPolicy(s.Conn().RemoteMultiaddr(), dialAddr) if isDialDataRequired && !as.limiter.AcceptDialDataRequest(p) { msg = pb.Message{ Msg: &pb.Message_DialResponse{ @@ -517,11 +517,11 @@ func (r *rateLimiter) Close() { // amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed // IP address is different from the dial back IP address -func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool { - connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr()) +func amplificationAttackPrevention(observedAddr, dialAddr ma.Multiaddr) bool { + observedIP, err := manet.ToIP(observedAddr) if err != nil { return true } - dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr - return !connIP.Equal(dialIP) + dialIP, _ := manet.ToIP(dialAddr) // must be an IP multiaddr + return !observedIP.Equal(dialIP) } diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 6664f1c397..f92396fb16 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -143,7 +143,7 @@ func TestServerDataRequest(t *testing.T) { dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) // ask for dial data for quic address an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { return true } @@ -197,7 +197,7 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { doneChan := make(chan struct{}) an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( // stall all allowed requests - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { <-doneChan return true }), @@ -255,7 +255,7 @@ func TestServerDataRequestJitter(t *testing.T) { dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) // ask for dial data for quic address an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { return true } @@ -520,6 +520,76 @@ func TestReadDialData(t *testing.T) { } } +func TestServerDataRequestWithAmplificationAttackPrevention(t *testing.T) { + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowPrivateAddrs, + WithServerRateLimit(10, 10, 10, 2), + withAmplificationAttackPreventionDialWait(0), + ) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowPrivateAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + err := c.host.Network().Listen(ma.StringCast("/ip6/::1/udp/0/quic-v1")) + if err != nil { + // machine doesn't have ipv6 + t.Skip("skipping test because machine doesn't have ipv6") + } + + var quicv4Addr ma.Multiaddr + var quicv6Addr ma.Multiaddr + for _, a := range c.host.Addrs() { + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + if _, err := a.ValueForProtocol(ma.P_IP4); err == nil { + quicv4Addr = a + } else { + quicv6Addr = a + } + } + } + res, err := c.GetReachability(context.Background(), []Request{{Addr: quicv4Addr, SendDialData: false}}) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: quicv4Addr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + + // ipv6 address should require dial data + _, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: false}}) + require.Error(t, err) + require.ErrorContains(t, err, "invalid dial data request: low priority addr") + + // ipv6 address should work fine with dial data + res, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: true}}) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: quicv6Addr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) +} + +func TestDefaultAmplificationAttackPrevention(t *testing.T) { + q1 := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") + q2 := ma.StringCast("/ip4/1.2.3.4/udp/1235/quic-v1") + t1 := ma.StringCast("/ip4/1.2.3.4/tcp/1234") + + require.False(t, amplificationAttackPrevention(q1, q1)) + require.False(t, amplificationAttackPrevention(q1, q2)) + require.False(t, amplificationAttackPrevention(q1, t1)) + + t2 := ma.StringCast("/ip4/1.1.1.1/tcp/1235") // different IP + require.True(t, amplificationAttackPrevention(q2, t2)) +} + func FuzzServerDialRequest(f *testing.F) { a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2)) c := newAutoNAT(f, nil)