Skip to content

Commit

Permalink
Infer addresses that share the same local thin waist
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Jul 31, 2024
1 parent 1fd9519 commit 6ec6085
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 49 deletions.
32 changes: 32 additions & 0 deletions p2p/protocol/identify/obsaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,36 @@ func (o *ObservedAddrManager) AddrsFor(addr ma.Multiaddr) (addrs []ma.Multiaddr)
return res
}

// appendInferredAddrs infers the external address of other transports that
// share the local thin waist with a transport that we have do observations for.
//
// e.g. If we have observations for a QUIC address on port 9000, and we are
// listening on the same interface and port 9000 for WebTransport, we can infer
// the external WebTransport address.
func (o *ObservedAddrManager) appendInferredAddrs(twToObserverSets map[string][]*observerSet, addrs []ma.Multiaddr) []ma.Multiaddr {
if twToObserverSets == nil {
twToObserverSets = make(map[string][]*observerSet)
for localTWStr := range o.externalAddrs {
twToObserverSets[localTWStr] = append(twToObserverSets[localTWStr], o.getTopExternalAddrs(localTWStr)...)
}
}
for _, a := range o.listenAddrs() {
if _, ok := o.localAddrs[string(a.Bytes())]; ok {
// We already have this address in the list
continue
}
a = o.normalize(a)
t, err := thinWaistForm(a)
if err != nil {
continue
}
for _, s := range twToObserverSets[string(t.TW.Bytes())] {
addrs = append(addrs, s.cacheMultiaddr(t.Rest))
}
}
return addrs
}

// Addrs return all activated observed addresses
func (o *ObservedAddrManager) Addrs() []ma.Multiaddr {
o.mu.RLock()
Expand All @@ -228,6 +258,8 @@ func (o *ObservedAddrManager) Addrs() []ma.Multiaddr {
addrs = append(addrs, s.cacheMultiaddr(t.Rest))
}
}

addrs = o.appendInferredAddrs(m, addrs)
return addrs
}

Expand Down
134 changes: 85 additions & 49 deletions p2p/protocol/identify/obsaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -35,35 +36,33 @@ func normalize(addr ma.Multiaddr) ma.Multiaddr {
}
}

func addrsEqual(a, b []ma.Multiaddr) bool {
if len(b) != len(a) {
return false
// subtractFrom takes the difference between two slices of multiaddrs (a - b)
func subtractFrom(a, b []ma.Multiaddr) []string {
bSet := make(map[string]struct{}, len(b))
for _, addr := range b {
bSet[string(addr.Bytes())] = struct{}{}
}
for _, x := range b {
found := false
for _, y := range a {
if y.Equal(x) {
found = true
break
}
}
if !found {
return false
out := make([]string, 0, len(a))
for _, addr := range a {
if _, ok := bSet[string(addr.Bytes())]; !ok {
out = append(out, addr.String())
}
}
for _, x := range a {
found := false
for _, y := range b {
if y.Equal(x) {
found = true
break
}
}
if !found {
return false
}
return out
}

func multiaddrsToStrings(a []ma.Multiaddr) []string {
out := make([]string, len(a))
for i, addr := range a {
out[i] = addr.String()
}
return true
return out
}

func addrsEqual(t assert.TestingT, a, b []ma.Multiaddr) bool {
aStr := multiaddrsToStrings(a)
bStr := multiaddrsToStrings(b)
return assert.ElementsMatch(t, aStr, bStr)
}

func TestObservedAddrManager(t *testing.T) {
Expand Down Expand Up @@ -107,7 +106,7 @@ func TestObservedAddrManager(t *testing.T) {
o.Record(c3, observed)
o.Record(c4, observed)
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observed})
return addrsEqual(t, o.Addrs(), []ma.Multiaddr{observed})
}, 1*time.Second, 100*time.Millisecond)
o.removeConn(c1)
o.removeConn(c2)
Expand All @@ -131,8 +130,35 @@ func TestObservedAddrManager(t *testing.T) {
o.Record(c2, observedQuic)
o.Record(c3, observedWebTransport)
o.Record(c4, observedWebTransport)
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport})
addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{})
}, 1*time.Second, 100*time.Millisecond)
o.removeConn(c1)
o.removeConn(c2)
o.removeConn(c3)
o.removeConn(c4)
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport})
return checkAllEntriesRemoved(o)
}, 1*time.Second, 100*time.Millisecond)
})

t.Run("WebTransport inferred from QUIC, with no WebTransport connections", func(t *testing.T) {
o := newObservedAddrMgr()
defer o.Close()
observedQuic := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1")
inferredWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1/webtransport")
c1 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.1/udp/1/quic-v1"))
c2 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1"))
c3 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1"))
c4 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"))
o.Record(c1, observedQuic)
o.Record(c2, observedQuic)
o.Record(c3, observedQuic)
o.Record(c4, observedQuic)
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, inferredWebTransport})
addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredWebTransport})
}, 1*time.Second, 100*time.Millisecond)
o.removeConn(c1)
o.removeConn(c2)
Expand All @@ -148,6 +174,7 @@ func TestObservedAddrManager(t *testing.T) {
defer o.Close()

observedQuic := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1")
inferredWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1/webtransport")

const N = 4 // ActivationThresh
var ob1, ob2 [N]connMultiaddrs
Expand All @@ -165,17 +192,18 @@ func TestObservedAddrManager(t *testing.T) {
// We should have a valid address now
o.Record(ob1[N-1], observedQuic)
o.Record(ob2[N-1], observedQuic)
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic})
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, inferredWebTransport})
addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredWebTransport})
}, 2*time.Second, 100*time.Millisecond)

// Now disconnect first observer group
for i := 0; i < N; i++ {
o.removeConn(ob1[i])
}
time.Sleep(100 * time.Millisecond)
if !addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic}) {
t.Fatalf("address removed too earyly %v %v", o.Addrs(), observedQuic)
if !addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, inferredWebTransport}) {
t.Fatalf("address removed too earyl %v %v", o.Addrs(), observedQuic)
}

// Now disconnect the second group to check cleanup
Expand All @@ -193,6 +221,8 @@ func TestObservedAddrManager(t *testing.T) {

observedQuic1 := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1")
observedQuic2 := ma.StringCast("/ip4/2.2.2.2/udp/3/quic-v1")
inferredWebTransport1 := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1/webtransport")
inferredWebTransport2 := ma.StringCast("/ip4/2.2.2.2/udp/3/quic-v1/webtransport")

const N = 4 // ActivationThresh
var ob1, ob2 [N]connMultiaddrs
Expand All @@ -210,17 +240,18 @@ func TestObservedAddrManager(t *testing.T) {
// We should have a valid address now
o.Record(ob1[N-1], observedQuic1)
o.Record(ob2[N-1], observedQuic2)
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic1, observedQuic2})
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic1, observedQuic2, inferredWebTransport1, inferredWebTransport2})
addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredWebTransport1, inferredWebTransport2})
}, 2*time.Second, 100*time.Millisecond)

// Now disconnect first observer group
for i := 0; i < N; i++ {
o.removeConn(ob1[i])
}
time.Sleep(100 * time.Millisecond)
if !addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic2}) {
t.Fatalf("address removed too earyly %v %v", o.Addrs(), observedQuic2)
if !addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic2, inferredWebTransport2}) {
t.Fatalf("address removed too early %v %v", o.Addrs(), observedQuic2)
}

// Now disconnect the second group to check cleanup
Expand Down Expand Up @@ -254,8 +285,8 @@ func TestObservedAddrManager(t *testing.T) {
time.Sleep(20 * time.Millisecond)
}

require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport})
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport})
}, 1*time.Second, 100*time.Millisecond)

tw, err := thinWaistForm(quic4ListenAddr)
Expand Down Expand Up @@ -333,8 +364,8 @@ func TestObservedAddrManager(t *testing.T) {
allAddrs = append(allAddrs, resTCPAddrs[:]...)
allAddrs = append(allAddrs, resQuicAddrs[:]...)
allAddrs = append(allAddrs, resWebTransportAddrs[:]...)
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), allAddrs)
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), allAddrs)
}, 1*time.Second, 100*time.Millisecond)

for i := 0; i < N; i++ {
Expand All @@ -350,6 +381,7 @@ func TestObservedAddrManager(t *testing.T) {
t.Run("WebTransport certhash", func(t *testing.T) {
o := newObservedAddrMgr()
observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport")
inferredQUIC := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1")
c1 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.1/udp/1/quic-v1/webtransport"))
c2 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1/webtransport"))
c3 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1/webtransport"))
Expand All @@ -358,8 +390,8 @@ func TestObservedAddrManager(t *testing.T) {
o.Record(c2, observedWebTransport)
o.Record(c3, observedWebTransport)
o.Record(c4, observedWebTransport)
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport})
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedWebTransport, inferredQUIC})
}, 1*time.Second, 100*time.Millisecond)
o.removeConn(c1)
o.removeConn(c2)
Expand All @@ -375,14 +407,16 @@ func TestObservedAddrManager(t *testing.T) {
defer o.Close()

observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport")
inferredQUIC := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1")
var udpConns [5 * maxExternalThinWaistAddrsPerLocalAddr]connMultiaddrs
for i := 0; i < len(udpConns); i++ {
udpConns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i)))
o.Record(udpConns[i], observedWebTransport)
time.Sleep(10 * time.Millisecond)
}
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport})
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedWebTransport, inferredQUIC})
addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredQUIC})
}, 1*time.Second, 100*time.Millisecond)

tcpNAT, udpNAT := o.getNATType()
Expand All @@ -409,8 +443,8 @@ func TestObservedAddrManager(t *testing.T) {
}
// At this point we have 20 groups with 5 observations for every connection
// The output should remain stable
require.Eventually(t, func() bool {
return len(o.Addrs()) == 2*maxExternalThinWaistAddrsPerLocalAddr
require.EventuallyWithT(t, func(t *assert.CollectT) {
require.Equal(t, len(subtractFrom(o.Addrs(), o.appendInferredAddrs(nil, nil))), 2*maxExternalThinWaistAddrsPerLocalAddr)
}, 1*time.Second, 100*time.Millisecond)

tcpNAT, udpNAT := o.getNATType()
Expand Down Expand Up @@ -457,14 +491,16 @@ func TestObservedAddrManager(t *testing.T) {
sub, err := bus.Subscribe(new(event.EvtNATDeviceTypeChanged))
require.NoError(t, err)
observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport")
inferredQUIC := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1")
var udpConns [5 * maxExternalThinWaistAddrsPerLocalAddr]connMultiaddrs
for i := 0; i < len(udpConns); i++ {
udpConns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i)))
o.Record(udpConns[i], observedWebTransport)
time.Sleep(10 * time.Millisecond)
}
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport})
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), []ma.Multiaddr{observedWebTransport, inferredQUIC})
addrsEqual(t, o.appendInferredAddrs(nil, nil), []ma.Multiaddr{inferredQUIC})
}, 1*time.Second, 100*time.Millisecond)

var e interface{}
Expand Down Expand Up @@ -551,8 +587,8 @@ func TestObservedAddrManager(t *testing.T) {
allAddrs = append(allAddrs, resTCPAddrs[:]...)
allAddrs = append(allAddrs, resQuicAddrs[:]...)
allAddrs = append(allAddrs, resWebTransportAddrs[:]...)
require.Eventually(t, func() bool {
return addrsEqual(o.Addrs(), allAddrs)
require.EventuallyWithT(t, func(t *assert.CollectT) {
addrsEqual(t, o.Addrs(), allAddrs)
}, 1*time.Second, 100*time.Millisecond)

for i := 0; i < N; i++ {
Expand Down

0 comments on commit 6ec6085

Please sign in to comment.