From 6ec6085a6b028d795eb909f5818af3f7facb8f46 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 31 Jul 2024 13:53:28 -0700 Subject: [PATCH] Infer addresses that share the same local thin waist --- p2p/protocol/identify/obsaddr.go | 32 ++++++ p2p/protocol/identify/obsaddr_test.go | 134 ++++++++++++++++---------- 2 files changed, 117 insertions(+), 49 deletions(-) diff --git a/p2p/protocol/identify/obsaddr.go b/p2p/protocol/identify/obsaddr.go index fc1c100c8e..7cb08db11d 100644 --- a/p2p/protocol/identify/obsaddr.go +++ b/p2p/protocol/identify/obsaddr.go @@ -213,6 +213,36 @@ func (o *ObservedAddrManager) AddrsFor(addr ma.Multiaddr) (addrs []ma.Multiaddr) return res } +// appendInferredAddrs infers the external address of other transports that +// share the local thin waist with a transport that we have do observations for. +// +// e.g. If we have observations for a QUIC address on port 9000, and we are +// listening on the same interface and port 9000 for WebTransport, we can infer +// the external WebTransport address. +func (o *ObservedAddrManager) appendInferredAddrs(twToObserverSets map[string][]*observerSet, addrs []ma.Multiaddr) []ma.Multiaddr { + if twToObserverSets == nil { + twToObserverSets = make(map[string][]*observerSet) + for localTWStr := range o.externalAddrs { + twToObserverSets[localTWStr] = append(twToObserverSets[localTWStr], o.getTopExternalAddrs(localTWStr)...) + } + } + for _, a := range o.listenAddrs() { + if _, ok := o.localAddrs[string(a.Bytes())]; ok { + // We already have this address in the list + continue + } + a = o.normalize(a) + t, err := thinWaistForm(a) + if err != nil { + continue + } + for _, s := range twToObserverSets[string(t.TW.Bytes())] { + addrs = append(addrs, s.cacheMultiaddr(t.Rest)) + } + } + return addrs +} + // Addrs return all activated observed addresses func (o *ObservedAddrManager) Addrs() []ma.Multiaddr { o.mu.RLock() @@ -228,6 +258,8 @@ func (o *ObservedAddrManager) Addrs() []ma.Multiaddr { addrs = append(addrs, s.cacheMultiaddr(t.Rest)) } } + + addrs = o.appendInferredAddrs(m, addrs) return addrs } diff --git a/p2p/protocol/identify/obsaddr_test.go b/p2p/protocol/identify/obsaddr_test.go index 94366f882e..5087c935ad 100644 --- a/p2p/protocol/identify/obsaddr_test.go +++ b/p2p/protocol/identify/obsaddr_test.go @@ -15,6 +15,7 @@ import ( ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -35,35 +36,33 @@ func normalize(addr ma.Multiaddr) ma.Multiaddr { } } -func addrsEqual(a, b []ma.Multiaddr) bool { - if len(b) != len(a) { - return false +// subtractFrom takes the difference between two slices of multiaddrs (a - b) +func subtractFrom(a, b []ma.Multiaddr) []string { + bSet := make(map[string]struct{}, len(b)) + for _, addr := range b { + bSet[string(addr.Bytes())] = struct{}{} } - for _, x := range b { - found := false - for _, y := range a { - if y.Equal(x) { - found = true - break - } - } - if !found { - return false + out := make([]string, 0, len(a)) + for _, addr := range a { + if _, ok := bSet[string(addr.Bytes())]; !ok { + out = append(out, addr.String()) } } - for _, x := range a { - found := false - for _, y := range b { - if y.Equal(x) { - found = true - break - } - } - if !found { - return false - } + return out +} + +func multiaddrsToStrings(a []ma.Multiaddr) []string { + out := make([]string, len(a)) + for i, addr := range a { + out[i] = addr.String() } - return true + return out +} + +func addrsEqual(t assert.TestingT, a, b []ma.Multiaddr) bool { + aStr := multiaddrsToStrings(a) + bStr := multiaddrsToStrings(b) + return assert.ElementsMatch(t, aStr, bStr) } func TestObservedAddrManager(t *testing.T) { @@ -107,7 +106,7 @@ func TestObservedAddrManager(t *testing.T) { o.Record(c3, observed) o.Record(c4, observed) require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observed}) + return addrsEqual(t, o.Addrs(), []ma.Multiaddr{observed}) }, 1*time.Second, 100*time.Millisecond) o.removeConn(c1) o.removeConn(c2) @@ -131,8 +130,35 @@ func TestObservedAddrManager(t *testing.T) { o.Record(c2, observedQuic) o.Record(c3, observedWebTransport) o.Record(c4, observedWebTransport) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport}) + addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{}) + }, 1*time.Second, 100*time.Millisecond) + o.removeConn(c1) + o.removeConn(c2) + o.removeConn(c3) + o.removeConn(c4) require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport}) + return checkAllEntriesRemoved(o) + }, 1*time.Second, 100*time.Millisecond) + }) + + t.Run("WebTransport inferred from QUIC, with no WebTransport connections", func(t *testing.T) { + o := newObservedAddrMgr() + defer o.Close() + observedQuic := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1") + inferredWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1/webtransport") + c1 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.1/udp/1/quic-v1")) + c2 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1")) + c3 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1")) + c4 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")) + o.Record(c1, observedQuic) + o.Record(c2, observedQuic) + o.Record(c3, observedQuic) + o.Record(c4, observedQuic) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, inferredWebTransport}) + addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredWebTransport}) }, 1*time.Second, 100*time.Millisecond) o.removeConn(c1) o.removeConn(c2) @@ -148,6 +174,7 @@ func TestObservedAddrManager(t *testing.T) { defer o.Close() observedQuic := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1") + inferredWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1/webtransport") const N = 4 // ActivationThresh var ob1, ob2 [N]connMultiaddrs @@ -165,8 +192,9 @@ func TestObservedAddrManager(t *testing.T) { // We should have a valid address now o.Record(ob1[N-1], observedQuic) o.Record(ob2[N-1], observedQuic) - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic}) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, inferredWebTransport}) + addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredWebTransport}) }, 2*time.Second, 100*time.Millisecond) // Now disconnect first observer group @@ -174,8 +202,8 @@ func TestObservedAddrManager(t *testing.T) { o.removeConn(ob1[i]) } time.Sleep(100 * time.Millisecond) - if !addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic}) { - t.Fatalf("address removed too earyly %v %v", o.Addrs(), observedQuic) + if !addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, inferredWebTransport}) { + t.Fatalf("address removed too earyl %v %v", o.Addrs(), observedQuic) } // Now disconnect the second group to check cleanup @@ -193,6 +221,8 @@ func TestObservedAddrManager(t *testing.T) { observedQuic1 := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1") observedQuic2 := ma.StringCast("/ip4/2.2.2.2/udp/3/quic-v1") + inferredWebTransport1 := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1/webtransport") + inferredWebTransport2 := ma.StringCast("/ip4/2.2.2.2/udp/3/quic-v1/webtransport") const N = 4 // ActivationThresh var ob1, ob2 [N]connMultiaddrs @@ -210,8 +240,9 @@ func TestObservedAddrManager(t *testing.T) { // We should have a valid address now o.Record(ob1[N-1], observedQuic1) o.Record(ob2[N-1], observedQuic2) - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic1, observedQuic2}) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic1, observedQuic2, inferredWebTransport1, inferredWebTransport2}) + addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredWebTransport1, inferredWebTransport2}) }, 2*time.Second, 100*time.Millisecond) // Now disconnect first observer group @@ -219,8 +250,8 @@ func TestObservedAddrManager(t *testing.T) { o.removeConn(ob1[i]) } time.Sleep(100 * time.Millisecond) - if !addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic2}) { - t.Fatalf("address removed too earyly %v %v", o.Addrs(), observedQuic2) + if !addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic2, inferredWebTransport2}) { + t.Fatalf("address removed too early %v %v", o.Addrs(), observedQuic2) } // Now disconnect the second group to check cleanup @@ -254,8 +285,8 @@ func TestObservedAddrManager(t *testing.T) { time.Sleep(20 * time.Millisecond) } - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport}) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport}) }, 1*time.Second, 100*time.Millisecond) tw, err := thinWaistForm(quic4ListenAddr) @@ -333,8 +364,8 @@ func TestObservedAddrManager(t *testing.T) { allAddrs = append(allAddrs, resTCPAddrs[:]...) allAddrs = append(allAddrs, resQuicAddrs[:]...) allAddrs = append(allAddrs, resWebTransportAddrs[:]...) - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), allAddrs) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), allAddrs) }, 1*time.Second, 100*time.Millisecond) for i := 0; i < N; i++ { @@ -350,6 +381,7 @@ func TestObservedAddrManager(t *testing.T) { t.Run("WebTransport certhash", func(t *testing.T) { o := newObservedAddrMgr() observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport") + inferredQUIC := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1") c1 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.1/udp/1/quic-v1/webtransport")) c2 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1/webtransport")) c3 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1/webtransport")) @@ -358,8 +390,8 @@ func TestObservedAddrManager(t *testing.T) { o.Record(c2, observedWebTransport) o.Record(c3, observedWebTransport) o.Record(c4, observedWebTransport) - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport}) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedWebTransport, inferredQUIC}) }, 1*time.Second, 100*time.Millisecond) o.removeConn(c1) o.removeConn(c2) @@ -375,14 +407,16 @@ func TestObservedAddrManager(t *testing.T) { defer o.Close() observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport") + inferredQUIC := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1") var udpConns [5 * maxExternalThinWaistAddrsPerLocalAddr]connMultiaddrs for i := 0; i < len(udpConns); i++ { udpConns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i))) o.Record(udpConns[i], observedWebTransport) time.Sleep(10 * time.Millisecond) } - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport}) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedWebTransport, inferredQUIC}) + addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredQUIC}) }, 1*time.Second, 100*time.Millisecond) tcpNAT, udpNAT := o.getNATType() @@ -409,8 +443,8 @@ func TestObservedAddrManager(t *testing.T) { } // At this point we have 20 groups with 5 observations for every connection // The output should remain stable - require.Eventually(t, func() bool { - return len(o.Addrs()) == 2*maxExternalThinWaistAddrsPerLocalAddr + require.EventuallyWithT(t, func(t *assert.CollectT) { + require.Equal(t, len(subtractFrom(o.Addrs(), o.appendInferredAddrs(nil, nil))), 2*maxExternalThinWaistAddrsPerLocalAddr) }, 1*time.Second, 100*time.Millisecond) tcpNAT, udpNAT := o.getNATType() @@ -457,14 +491,16 @@ func TestObservedAddrManager(t *testing.T) { sub, err := bus.Subscribe(new(event.EvtNATDeviceTypeChanged)) require.NoError(t, err) observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport") + inferredQUIC := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1") var udpConns [5 * maxExternalThinWaistAddrsPerLocalAddr]connMultiaddrs for i := 0; i < len(udpConns); i++ { udpConns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i))) o.Record(udpConns[i], observedWebTransport) time.Sleep(10 * time.Millisecond) } - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport}) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedWebTransport, inferredQUIC}) + addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredQUIC}) }, 1*time.Second, 100*time.Millisecond) var e interface{} @@ -551,8 +587,8 @@ func TestObservedAddrManager(t *testing.T) { allAddrs = append(allAddrs, resTCPAddrs[:]...) allAddrs = append(allAddrs, resQuicAddrs[:]...) allAddrs = append(allAddrs, resWebTransportAddrs[:]...) - require.Eventually(t, func() bool { - return addrsEqual(o.Addrs(), allAddrs) + require.EventuallyWithT(t, func(t *assert.CollectT) { + addrsEqual(t, o.Addrs(), allAddrs) }, 1*time.Second, 100*time.Millisecond) for i := 0; i < N; i++ {