From 77eba5739e684155fe2b2225e22ecab5dc841a01 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 21 Mar 2025 23:02:03 +0530 Subject: [PATCH] autonatv2: fix server dial data request policy The policy was comparing the connection local addr to the observed addr. This should be comparing the connection remote addr to the requested dial addr. The impact here is: if we refresh reachability for addresses every hour, we will be spending 100kB per address per hour. That's equivalent to 30 B/s. For 10 addrs this will be 300 B/s or 3kb/s --- p2p/protocol/autonatv2/server.go | 12 ++--- p2p/protocol/autonatv2/server_test.go | 76 +++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index 05e0bdd9fd..dafac4dfe2 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -30,7 +30,7 @@ var ( errDialDataRefused = errors.New("dial data refused") ) -type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool +type dataRequestPolicyFunc = func(observedAddr, dialAddr ma.Multiaddr) bool type EventDialRequestCompleted struct { Error error @@ -212,7 +212,7 @@ func (as *server) serveDialRequest(s network.Stream) EventDialRequestCompleted { nonce := msg.GetDialRequest().Nonce - isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr) + isDialDataRequired := as.dialDataRequestPolicy(s.Conn().RemoteMultiaddr(), dialAddr) if isDialDataRequired && !as.limiter.AcceptDialDataRequest(p) { msg = pb.Message{ Msg: &pb.Message_DialResponse{ @@ -517,11 +517,11 @@ func (r *rateLimiter) Close() { // amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed // IP address is different from the dial back IP address -func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool { - connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr()) +func amplificationAttackPrevention(observedAddr, dialAddr ma.Multiaddr) bool { + observedIP, err := manet.ToIP(observedAddr) if err != nil { return true } - dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr - return !connIP.Equal(dialIP) + dialIP, _ := manet.ToIP(dialAddr) // must be an IP multiaddr + return !observedIP.Equal(dialIP) } diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 6664f1c397..f92396fb16 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -143,7 +143,7 @@ func TestServerDataRequest(t *testing.T) { dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) // ask for dial data for quic address an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { return true } @@ -197,7 +197,7 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { doneChan := make(chan struct{}) an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( // stall all allowed requests - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { <-doneChan return true }), @@ -255,7 +255,7 @@ func TestServerDataRequestJitter(t *testing.T) { dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) // ask for dial data for quic address an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { return true } @@ -520,6 +520,76 @@ func TestReadDialData(t *testing.T) { } } +func TestServerDataRequestWithAmplificationAttackPrevention(t *testing.T) { + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowPrivateAddrs, + WithServerRateLimit(10, 10, 10, 2), + withAmplificationAttackPreventionDialWait(0), + ) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowPrivateAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + err := c.host.Network().Listen(ma.StringCast("/ip6/::1/udp/0/quic-v1")) + if err != nil { + // machine doesn't have ipv6 + t.Skip("skipping test because machine doesn't have ipv6") + } + + var quicv4Addr ma.Multiaddr + var quicv6Addr ma.Multiaddr + for _, a := range c.host.Addrs() { + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + if _, err := a.ValueForProtocol(ma.P_IP4); err == nil { + quicv4Addr = a + } else { + quicv6Addr = a + } + } + } + res, err := c.GetReachability(context.Background(), []Request{{Addr: quicv4Addr, SendDialData: false}}) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: quicv4Addr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + + // ipv6 address should require dial data + _, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: false}}) + require.Error(t, err) + require.ErrorContains(t, err, "invalid dial data request: low priority addr") + + // ipv6 address should work fine with dial data + res, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: true}}) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: quicv6Addr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) +} + +func TestDefaultAmplificationAttackPrevention(t *testing.T) { + q1 := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") + q2 := ma.StringCast("/ip4/1.2.3.4/udp/1235/quic-v1") + t1 := ma.StringCast("/ip4/1.2.3.4/tcp/1234") + + require.False(t, amplificationAttackPrevention(q1, q1)) + require.False(t, amplificationAttackPrevention(q1, q2)) + require.False(t, amplificationAttackPrevention(q1, t1)) + + t2 := ma.StringCast("/ip4/1.1.1.1/tcp/1235") // different IP + require.True(t, amplificationAttackPrevention(q2, t2)) +} + func FuzzServerDialRequest(f *testing.F) { a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2)) c := newAutoNAT(f, nil)