From 818a700cd843ba3ac836129836f552bd4e4c9e5d Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 27 Feb 2025 22:30:39 +0530 Subject: [PATCH 01/15] basichost: use autonatv2 to verify reachability This introduces `addrsReachabilityTracker` that tracks reachability on a set of addresses. It probes reachability for addresses periodically and has an exponential backoff in case there are too many errors or we don't have any valid autonatv2 peer. There's no smartness in the address selection logic currently. We just test all provided addresses. It also doesn't use the addresses provided by `AddrsFactory`, so currently there's no way to get a user-provided address tested for reachability, something that would be a problem for DNS addresses. I intend to introduce an alternative to `AddrsFactory`, something like `AnnounceAddrs(addrs []ma.Multiaddr)`, whose addresses are just appended to the set of addresses that we have, and check reachability for those addresses too. There's only one method exposed on the BasicHost right now, `ReachableAddrs() []ma.Multiaddr`, which returns the host's reachable addrs. Users can also subscribe to the event `EvtHostReachableAddrsChanged` to be notified when the reachability of any address changes. 
--- core/event/reachability.go | 10 + p2p/host/basic/addrs_manager.go | 242 +++++-- p2p/host/basic/addrs_manager_test.go | 63 +- p2p/host/basic/addrs_reachability_tracker.go | 606 ++++++++++++++++++ .../basic/addrs_reachability_tracker_test.go | 597 +++++++++++++++++ p2p/host/basic/basic_host.go | 42 +- p2p/host/basic/basic_host_test.go | 2 + p2p/protocol/autonatv2/autonat.go | 96 ++- p2p/protocol/autonatv2/autonat_test.go | 51 +- p2p/protocol/autonatv2/client.go | 3 +- p2p/protocol/autonatv2/options.go | 9 + p2p/protocol/autonatv2/server.go | 2 +- 12 files changed, 1537 insertions(+), 186 deletions(-) create mode 100644 p2p/host/basic/addrs_reachability_tracker.go create mode 100644 p2p/host/basic/addrs_reachability_tracker_test.go diff --git a/core/event/reachability.go b/core/event/reachability.go index 6ab4523851..8aa2cd07cc 100644 --- a/core/event/reachability.go +++ b/core/event/reachability.go @@ -2,6 +2,7 @@ package event import ( "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" ) // EvtLocalReachabilityChanged is an event struct to be emitted when the local's @@ -11,3 +12,12 @@ import ( type EvtLocalReachabilityChanged struct { Reachability network.Reachability } + +// EvtHostReachableAddrsChanged is sent when host's reachable or unreachable addresses change +// Reachable and Unreachable both contain only Public IP or DNS addresses +// +// Experimental: This API is unstable. Any changes to this event will be done without a deprecation notice. 
+type EvtHostReachableAddrsChanged struct { + Reachable []ma.Multiaddr + Unreachable []ma.Multiaddr +} diff --git a/p2p/host/basic/addrs_manager.go b/p2p/host/basic/addrs_manager.go index 6b984a9dd5..3c09c98bb8 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/libp2p/go-netroute" @@ -27,24 +28,38 @@ type observedAddrsManager interface { ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr } +type hostAddrs struct { + addrs []ma.Multiaddr + localAddrs []ma.Multiaddr + reachableAddrs []ma.Multiaddr + unreachableAddrs []ma.Multiaddr +} + type addrsManager struct { - eventbus event.Bus - natManager NATManager - addrsFactory AddrsFactory - listenAddrs func() []ma.Multiaddr - transportForListening func(ma.Multiaddr) transport.Transport - observedAddrsManager observedAddrsManager - interfaceAddrs *interfaceAddrsCache + bus event.Bus + natManager NATManager + addrsFactory AddrsFactory + listenAddrs func() []ma.Multiaddr + transportForListening func(ma.Multiaddr) transport.Transport + observedAddrsManager observedAddrsManager + interfaceAddrs *interfaceAddrsCache + addrsReachabilityTracker *addrsReachabilityTracker + + // addrsUpdatedChan is notified when addrs change. This is provided by the caller. + addrsUpdatedChan chan struct{} // triggerAddrsUpdateChan is used to trigger an addresses update. triggerAddrsUpdateChan chan struct{} - // addrsUpdatedChan is notified when addresses change. - addrsUpdatedChan chan struct{} + // triggerReachabilityUpdate is notified when reachable addrs are updated. 
+ triggerReachabilityUpdate chan struct{} + // triggerHostReachabilityUpdate is notified when host's reachability from autonat v1 changes. + triggerHostReachabilityUpdate chan struct{} + hostReachability atomic.Pointer[network.Reachability] - addrsMx sync.RWMutex // protects fields below - localAddrs []ma.Multiaddr - relayAddrs []ma.Multiaddr + addrsMx sync.RWMutex // protects fields below + currentAddrs hostAddrs + relayAddrs []ma.Multiaddr wg sync.WaitGroup ctx context.Context @@ -52,35 +67,54 @@ type addrsManager struct { } func newAddrsManager( - eventbus event.Bus, + bus event.Bus, natmgr NATManager, addrsFactory AddrsFactory, listenAddrs func() []ma.Multiaddr, transportForListening func(ma.Multiaddr) transport.Transport, observedAddrsManager observedAddrsManager, addrsUpdatedChan chan struct{}, + client autonatv2Client, ) (*addrsManager, error) { ctx, cancel := context.WithCancel(context.Background()) as := &addrsManager{ - eventbus: eventbus, - listenAddrs: listenAddrs, - transportForListening: transportForListening, - observedAddrsManager: observedAddrsManager, - natManager: natmgr, - addrsFactory: addrsFactory, - triggerAddrsUpdateChan: make(chan struct{}, 1), - addrsUpdatedChan: addrsUpdatedChan, - interfaceAddrs: &interfaceAddrsCache{}, - ctx: ctx, - ctxCancel: cancel, + bus: bus, + listenAddrs: listenAddrs, + transportForListening: transportForListening, + observedAddrsManager: observedAddrsManager, + natManager: natmgr, + addrsFactory: addrsFactory, + triggerAddrsUpdateChan: make(chan struct{}, 1), + triggerHostReachabilityUpdate: make(chan struct{}, 1), + triggerReachabilityUpdate: make(chan struct{}, 1), + addrsUpdatedChan: addrsUpdatedChan, + interfaceAddrs: &interfaceAddrsCache{}, + ctx: ctx, + ctxCancel: cancel, } unknownReachability := network.ReachabilityUnknown as.hostReachability.Store(&unknownReachability) + + if client != nil { + as.addrsReachabilityTracker = newAddrsReachabilityTracker(client, as.triggerReachabilityUpdate, nil) + } return 
as, nil } func (a *addrsManager) Start() error { - return a.background() + // TODO: add Start method to NATMgr + if a.addrsReachabilityTracker != nil { + err := a.addrsReachabilityTracker.Start() + if err != nil { + return fmt.Errorf("error starting addrs reachability tracker: %s", err) + } + } + + err := a.background() + if err != nil { + return err + } + return nil } func (a *addrsManager) Close() { @@ -91,10 +125,18 @@ func (a *addrsManager) Close() { log.Warnf("error closing natmgr: %s", err) } } + if a.addrsReachabilityTracker != nil { + err := a.addrsReachabilityTracker.Close() + if err != nil { + log.Warnf("error closing addrs reachability tracker: %s", err) + } + } a.wg.Wait() } func (a *addrsManager) NetNotifee() network.Notifiee { + // Updating addrs in sync provides the nice property that + // host.Addrs() just after host.Network().Listen(x) will return x return &network.NotifyBundle{ ListenF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() }, ListenCloseF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() }, @@ -102,9 +144,7 @@ func (a *addrsManager) NetNotifee() network.Notifiee { } func (a *addrsManager) triggerAddrsUpdate() { - // This is ugly, we update here *and* in the background loop, but this ensures the nice property - // that host.Addrs after host.Network().Listen(...) will return the recently added listen address. 
- a.updateLocalAddrs() + a.updateAddrs() select { case a.triggerAddrsUpdateChan <- struct{}{}: default: @@ -112,19 +152,21 @@ func (a *addrsManager) triggerAddrsUpdate() { } func (a *addrsManager) background() error { - autoRelayAddrsSub, err := a.eventbus.Subscribe(new(event.EvtAutoRelayAddrsUpdated)) + autoRelayAddrsSub, err := a.bus.Subscribe(new(event.EvtAutoRelayAddrsUpdated), eventbus.Name("addrs-manager")) if err != nil { return fmt.Errorf("error subscribing to auto relay addrs: %s", err) } - autonatReachabilitySub, err := a.eventbus.Subscribe(new(event.EvtLocalReachabilityChanged)) + autonatReachabilitySub, err := a.bus.Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("addrs-manager")) if err != nil { return fmt.Errorf("error subscribing to autonat reachability: %s", err) } - // ensure that we have the correct address after returning from Start() - // update local addrs - a.updateLocalAddrs() + emitter, err := a.bus.Emitter(new(event.EvtHostReachableAddrsChanged), eventbus.Stateful) + if err != nil { + return fmt.Errorf("error creating host reachable addrs emitter: %w", err) + } + // update relay addrs in case we're private select { case e := <-autoRelayAddrsSub.Out(): @@ -133,6 +175,7 @@ func (a *addrsManager) background() error { } default: } + select { case e := <-autonatReachabilitySub.Out(): if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { @@ -140,6 +183,7 @@ func (a *addrsManager) background() error { } default: } + a.updateAddrs() a.wg.Add(1) go func() { @@ -159,18 +203,11 @@ func (a *addrsManager) background() error { ticker := time.NewTicker(addrChangeTickrInterval) defer ticker.Stop() - var prev []ma.Multiaddr + var previousAddrs hostAddrs for { - a.updateLocalAddrs() - curr := a.Addrs() - if a.areAddrsDifferent(prev, curr) { - log.Debugf("host addresses updated: %s", curr) - select { - case a.addrsUpdatedChan <- struct{}{}: - default: - } - } - prev = curr + currAddrs := a.updateAddrs() + a.notifyAddrsChanged(emitter, 
previousAddrs, currAddrs) + previousAddrs = currAddrs select { case <-ticker.C: case <-a.triggerAddrsUpdateChan: @@ -178,6 +215,7 @@ func (a *addrsManager) background() error { if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { a.updateRelayAddrs(evt.RelayAddrs) } + case <-a.triggerReachabilityUpdate: case e := <-autonatReachabilitySub.Out(): if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { a.hostReachability.Store(&evt.Reachability) @@ -190,20 +228,72 @@ func (a *addrsManager) background() error { return nil } +func (a *addrsManager) updateAddrs() hostAddrs { + localAddrs := a.getLocalAddrs() + var currReachableAddrs, currUnreachableAddrs []ma.Multiaddr + if a.addrsReachabilityTracker != nil { + currReachableAddrs, currUnreachableAddrs = a.getConfirmedAddrs(localAddrs) + } + currAddrs := a.getAddrs(slices.Clone(localAddrs), a.RelayAddrs()) + + // maybe we can avoid this clone? + a.addrsMx.Lock() + a.currentAddrs.addrs = append(a.currentAddrs.addrs[:0], currAddrs...) + a.currentAddrs.localAddrs = append(a.currentAddrs.localAddrs[:0], localAddrs...) + a.currentAddrs.reachableAddrs = append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...) + a.currentAddrs.unreachableAddrs = append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...) 
+ a.addrsMx.Unlock() + + return hostAddrs{ + localAddrs: localAddrs, + addrs: currAddrs, + reachableAddrs: currReachableAddrs, + unreachableAddrs: currUnreachableAddrs, + } +} + +func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) { + if areAddrsDifferent(previous.localAddrs, current.localAddrs) { + log.Debugf("host addresses updated: %s", current.localAddrs) + if a.addrsReachabilityTracker != nil { + a.addrsReachabilityTracker.UpdateAddrs(current.localAddrs) + } + } + if areAddrsDifferent(previous.addrs, current.addrs) { + select { + case a.addrsUpdatedChan <- struct{}{}: + default: + } + } + + if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) || + areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) { + if err := emitter.Emit(event.EvtHostReachableAddrsChanged{ + Reachable: slices.Clone(current.reachableAddrs), + Unreachable: slices.Clone(current.unreachableAddrs), + }); err != nil { + log.Errorf("error sending host reachable addrs changed event: %s", err) + } + } +} + // Addrs returns the node's dialable addresses both public and private. // If autorelay is enabled and node reachability is private, it returns // the node's relay addresses and private network addresses. func (a *addrsManager) Addrs() []ma.Multiaddr { - addrs := a.DirectAddrs() + return a.getAddrs(a.DirectAddrs(), a.RelayAddrs()) +} + +// mutates localAddrs +func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr { + addrs := localAddrs rch := a.hostReachability.Load() if rch != nil && *rch == network.ReachabilityPrivate { - a.addrsMx.RLock() // Delete public addresses if the node's reachability is private, and we have relay addresses - if len(a.relayAddrs) > 0 { + if len(relayAddrs) > 0 { addrs = slices.DeleteFunc(addrs, manet.IsPublicAddr) - addrs = append(addrs, a.relayAddrs...) + addrs = append(addrs, relayAddrs...) } - a.addrsMx.RUnlock() } // Make a copy. 
Consumers can modify the slice elements addrs = slices.Clone(a.addrsFactory(addrs)) @@ -212,8 +302,6 @@ func (a *addrsManager) Addrs() []ma.Multiaddr { slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) return addrs } - -// HolePunchAddrs returns the node's public direct listen addresses for hole punching. func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr { addrs := a.DirectAddrs() addrs = slices.Clone(a.addrsFactory(addrs)) @@ -230,7 +318,20 @@ func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr { func (a *addrsManager) DirectAddrs() []ma.Multiaddr { a.addrsMx.RLock() defer a.addrsMx.RUnlock() - return slices.Clone(a.localAddrs) + return slices.Clone(a.currentAddrs.localAddrs) +} + +// ReachableAddrs returns all addresses of the host that are reachable from the internet +func (a *addrsManager) ReachableAddrs() []ma.Multiaddr { + a.addrsMx.RLock() + defer a.addrsMx.RUnlock() + return slices.Clone(a.currentAddrs.reachableAddrs) +} + +func (a *addrsManager) RelayAddrs() []ma.Multiaddr { + a.addrsMx.RLock() + defer a.addrsMx.RUnlock() + return slices.Clone(a.relayAddrs) } func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) { @@ -239,17 +340,21 @@ func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) { a.relayAddrs = append(a.relayAddrs[:0], addrs...) } -var p2pCircuitAddr = ma.StringCast("/p2p-circuit") - -func (a *addrsManager) updateLocalAddrs() { - localAddrs := a.getLocalAddrs() - slices.SortFunc(localAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) - - a.addrsMx.Lock() - a.localAddrs = localAddrs - a.addrsMx.Unlock() +func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + reachableAddrs, unreachableAddrs = a.addrsReachabilityTracker.ConfirmedAddrs() + // Only include relevant host addresses as the reachability manager may have + // a stale view of host's addresses. 
+ reachableAddrs = slices.DeleteFunc(reachableAddrs, func(a ma.Multiaddr) bool { + return !contains(localAddrs, a) + }) + unreachableAddrs = slices.DeleteFunc(unreachableAddrs, func(a ma.Multiaddr) bool { + return !contains(localAddrs, a) + }) + return reachableAddrs, unreachableAddrs } +var p2pCircuitAddr = ma.StringCast("/p2p-circuit") + func (a *addrsManager) getLocalAddrs() []ma.Multiaddr { listenAddrs := a.listenAddrs() if len(listenAddrs) == 0 { @@ -260,8 +365,6 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr { finalAddrs = a.appendPrimaryInterfaceAddrs(finalAddrs, listenAddrs) finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs, a.interfaceAddrs.All()) - finalAddrs = ma.Unique(finalAddrs) - // Remove "/p2p-circuit" addresses from the list. // The p2p-circuit listener reports its address as just /p2p-circuit. This is // useless for dialing. Users need to manage their circuit addresses themselves, @@ -278,6 +381,8 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr { // Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered // using identify. 
finalAddrs = a.addCertHashes(finalAddrs) + finalAddrs = ma.Unique(finalAddrs) + slices.SortFunc(finalAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) return finalAddrs } @@ -408,7 +513,7 @@ func (a *addrsManager) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs } -func (a *addrsManager) areAddrsDifferent(prev, current []ma.Multiaddr) bool { +func areAddrsDifferent(prev, current []ma.Multiaddr) bool { // TODO: make the sorted nature of ma.Unique a guarantee in multiaddrs prev = ma.Unique(prev) current = ma.Unique(current) @@ -425,6 +530,15 @@ func (a *addrsManager) areAddrsDifferent(prev, current []ma.Multiaddr) bool { return false } +func contains(addrs []ma.Multiaddr, addr ma.Multiaddr) bool { + for _, a := range addrs { + if a.Equal(addr) { + return true + } + } + return false +} + const interfaceAddrsCacheTTL = time.Minute type interfaceAddrsCache struct { diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index 49e46f2530..9b01149dae 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -1,6 +1,8 @@ package basichost import ( + "context" + "errors" "fmt" "testing" "time" @@ -8,6 +10,8 @@ import ( "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/assert" @@ -170,6 +174,8 @@ type addrsManagerArgs struct { AddrsFactory AddrsFactory ObservedAddrsManager observedAddrsManager ListenAddrs func() []ma.Multiaddr + AutoNATClient autonatv2Client + Bus event.Bus } type addrsManagerTestCase struct { @@ -179,13 +185,16 @@ type addrsManagerTestCase struct { } func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTestCase { - eb := 
eventbus.NewBus() + eb := args.Bus + if eb == nil { + eb = eventbus.NewBus() + } if args.AddrsFactory == nil { args.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs } } addrsUpdatedChan := make(chan struct{}, 1) am, err := newAddrsManager( - eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan, + eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan, args.AutoNATClient, ) require.NoError(t, err) @@ -196,6 +205,7 @@ func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTe rchEm, err := eb.Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) require.NoError(t, err) + t.Cleanup(am.Close) return addrsManagerTestCase{ addrsManager: am, PushRelay: func(relayAddrs []ma.Multiaddr) { @@ -425,17 +435,62 @@ func TestAddrsManager(t *testing.T) { }) } +func TestAddrsManagerReachabilityEvent(t *testing.T) { + // Setup test addresses + publicQUIC, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1") + publicTCP, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") + + // Create a new event bus + bus := eventbus.NewBus() + + // Subscribe to EvtHostReachableAddrsChanged events + sub, err := bus.Subscribe(new(event.EvtHostReachableAddrsChanged)) + require.NoError(t, err) + defer sub.Close() + + am := newAddrsManagerTestCase(t, addrsManagerArgs{ + Bus: bus, + // currently they aren't being passed to the reachability tracker + ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicTCP} }, + AutoNATClient: mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + if reqs[0].Addr.Equal(publicQUIC) { + return autonatv2.Result{Addr: reqs[0].Addr, Status: pb.DialStatus_OK}, nil + } else if reqs[0].Addr.Equal(publicTCP) { + return autonatv2.Result{Addr: reqs[0].Addr, Status: pb.DialStatus_E_DIAL_ERROR}, nil + } + t.Errorf("received invalid request for addr: %+v", 
reqs[0]) + return autonatv2.Result{}, errors.New("invalid") + }, + }, + }) + + reachableAddrs := []ma.Multiaddr{publicQUIC} + unreachableAddrs := []ma.Multiaddr{publicTCP} + + // No new event should be received + select { + case e := <-sub.Out(): + evt := e.(event.EvtHostReachableAddrsChanged) + // Verify the event contains the expected addresses + require.ElementsMatch(t, reachableAddrs, evt.Reachable) + require.ElementsMatch(t, unreachableAddrs, evt.Unreachable) + require.ElementsMatch(t, reachableAddrs, am.ReachableAddrs()) + case <-time.After(5 * time.Second): + t.Fatal("expected event for reachability change") + } +} + func BenchmarkAreAddrsDifferent(b *testing.B) { var addrs [10]ma.Multiaddr for i := 0; i < len(addrs); i++ { addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i)) } - am := &addrsManager{} b.Run("areAddrsDifferent", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - am.areAddrsDifferent(addrs[:], addrs[:]) + areAddrsDifferent(addrs[:], addrs[:]) } }) } diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go new file mode 100644 index 0000000000..e0b385dc4b --- /dev/null +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -0,0 +1,606 @@ +package basichost + +import ( + "context" + "errors" + "fmt" + "math" + "slices" + "sync" + "time" + + "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +type autonatv2Client interface { + GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) +} + +const ( + + // maxAddrsPerRequest is the maximum number of addresses to probe in a single request + maxAddrsPerRequest = 10 + // maxTrackedAddrs is the maximum number of addresses to track + // 10 addrs per 
transport for 5 transports + maxTrackedAddrs = 50 + // defaultMaxConcurrency is the default number of concurrent workers for reachability checks + defaultMaxConcurrency = 5 +) + +type addrsReachabilityTracker struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + cli autonatv2Client + // reachabilityUpdateCh is used to notify when reachability may have changed + reachabilityUpdateCh chan struct{} + maxConcurrency int + newAddrsProbeDelay time.Duration + addrTracker *addrsProbeTracker + newAddrs chan []ma.Multiaddr + clock clock.Clock + + mx sync.Mutex + reachableAddrs []ma.Multiaddr + unreachableAddrs []ma.Multiaddr +} + +// newAddrsReachabilityTracker tracks reachability for addresses. +// Use UpdateAddrs to provide addresses for tracking reachability. +// reachabilityUpdateCh is notified when any reachability probes are made. The reader must dedup the events. It may be +// notified even when the reachability for any addrs has not changed. +func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh chan struct{}, cl clock.Clock) *addrsReachabilityTracker { + ctx, cancel := context.WithCancel(context.Background()) + if cl == nil { + cl = clock.New() + } + return &addrsReachabilityTracker{ + ctx: ctx, + cancel: cancel, + cli: client, + reachabilityUpdateCh: reachabilityUpdateCh, + addrTracker: newAddrsTracker(cl.Now, maxRecentProbeResultWindow), + newAddrsProbeDelay: 1 * time.Second, + maxConcurrency: defaultMaxConcurrency, + newAddrs: make(chan []ma.Multiaddr, 1), + clock: cl, + } +} + +func (r *addrsReachabilityTracker) UpdateAddrs(addrs []ma.Multiaddr) { + r.newAddrs <- slices.Clone(addrs) +} + +func (r *addrsReachabilityTracker) Start() error { + r.wg.Add(1) + err := r.background() + if err != nil { + return err + } + return nil +} + +func (r *addrsReachabilityTracker) Close() error { + r.cancel() + r.wg.Wait() + return nil +} + +const defaultResetInterval = 5 * time.Minute + +func (r *addrsReachabilityTracker) 
background() error { + go func() { + defer r.wg.Done() + + timer := r.clock.Timer(time.Duration(math.MaxInt64)) + defer timer.Stop() + + var task reachabilityTask + var backoffInterval time.Duration + var reachable, unreachable []ma.Multiaddr // used to avoid allocations + for { + select { + case <-timer.C: + if task.RespCh == nil { + task = r.refreshReachability() + } + timer.Reset(defaultResetInterval) + case backoff := <-task.RespCh: + task = reachabilityTask{} + if backoff { + backoffInterval = newBackoffInterval(backoffInterval) + } else { + backoffInterval = 0 + } + reachable, unreachable = r.appendConfirmedAddrsAndNotify(reachable[:0], unreachable[:0]) + timer.Reset(backoffInterval) + case addrs := <-r.newAddrs: + if task.RespCh != nil { + task.Cancel() + <-task.RespCh + task = reachabilityTask{} + // We must send the event here. If there are no new addrs in this event we may not probe + // again for a while delaying any reachability updates. + reachable, unreachable = r.appendConfirmedAddrsAndNotify(reachable[:0], unreachable[:0]) + } + addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { + return !manet.IsPublicAddr(a) + }) + if len(addrs) > maxTrackedAddrs { + log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs) + addrs = addrs[:maxTrackedAddrs] + } + r.addrTracker.UpdateAddrs(addrs) + timer.Reset(r.newAddrsProbeDelay) + case <-r.ctx.Done(): + if task.RespCh != nil { + task.Cancel() + <-task.RespCh + task = reachabilityTask{} + } + return + } + } + }() + return nil +} + +func (r *addrsReachabilityTracker) appendConfirmedAddrsAndNotify(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + reachable, unreachable = r.addrTracker.AppendConfirmedAddrs(reachable, unreachable) + r.mx.Lock() + r.reachableAddrs = append(r.reachableAddrs[:0], reachable...) + r.unreachableAddrs = append(r.unreachableAddrs[:0], unreachable...) 
+ r.mx.Unlock() + select { + case r.reachabilityUpdateCh <- struct{}{}: + default: + } + return reachable, unreachable +} + +func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + r.mx.Lock() + defer r.mx.Unlock() + return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs) +} + +const ( + backoffStartInterval = 5 * time.Second + maxBackoffInterval = 2 * defaultResetInterval +) + +func newBackoffInterval(current time.Duration) time.Duration { + if current == 0 { + return backoffStartInterval + } + current *= 2 + if current > maxBackoffInterval { + return maxBackoffInterval + } + return current +} + +type reachabilityTask struct { + Cancel context.CancelFunc + RespCh chan bool +} + +func (r *addrsReachabilityTracker) refreshReachability() reachabilityTask { + if len(r.addrTracker.GetProbe()) == 0 { + return reachabilityTask{} + } + resCh := make(chan bool, 1) + ctx, cancel := context.WithTimeout(r.ctx, 5*time.Minute) + r.wg.Add(1) + go func() { + defer r.wg.Done() + defer cancel() + backoff := runProbes(ctx, r.maxConcurrency, r.addrTracker, r.cli) + resCh <- backoff + }() + return reachabilityTask{Cancel: cancel, RespCh: resCh} +} + +var errTooManyConsecutiveFailures = errors.New("too many consecutive failures") + +// errCountingClient counts errors from autonatv2Client and wraps the errors in response with a +// errTooManyConsecutiveFailures in case of many consecutive failures +type errCountingClient struct { + autonatv2Client + MaxConsecutiveErrors int + mx sync.Mutex + consecutiveErrors int +} + +func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + res, err := c.autonatv2Client.GetReachability(ctx, reqs) + c.mx.Lock() + defer c.mx.Unlock() + if err == nil || errors.Is(err, autonatv2.ErrDialRefused) || errors.Is(err, autonatv2.ErrNoValidPeers) { + c.consecutiveErrors = 0 + } else { + c.consecutiveErrors++ + if c.consecutiveErrors > 
c.MaxConsecutiveErrors { + err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err) + } + } + return res, err +} + +type probeResponse struct { + Requests []autonatv2.Request + Result autonatv2.Result + Err error +} + +const maxConsecutiveErrors = 20 + +// runProbes runs probes provided by addrsTracker with the given client. It returns true if the caller should +// backoff before retrying probes. It stops probing when any of the following happens: +// - there are no more probes to run +// - context is completed +// - there are too many consecutive failures from the client +// - the client has no valid peers to probe +func runProbes(ctx context.Context, concurrency int, addrsTracker *addrsProbeTracker, client autonatv2Client) bool { + client = &errCountingClient{autonatv2Client: client, MaxConsecutiveErrors: maxConsecutiveErrors} + + resultsCh := make(chan probeResponse, 2*concurrency) // enough buffer to allow all worker goroutines to exit quickly + jobsCh := make(chan []autonatv2.Request, 1) // close jobs to terminate the workers + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + for reqs := range jobsCh { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + res, err := client.GetReachability(ctx, reqs) + cancel() + resultsCh <- probeResponse{Requests: reqs, Result: res, Err: err} + } + }() + } + + nextProbe := addrsTracker.GetProbe() + backoff := false +outer: + for jc := jobsCh; addrsTracker.InProgressProbes() > 0 || len(nextProbe) > 0; { + select { + case jc <- nextProbe: + addrsTracker.MarkProbeInProgress(nextProbe) + case resp := <-resultsCh: + addrsTracker.CompleteProbe(resp.Requests, resp.Result, resp.Err) + if errors.Is(resp.Err, autonatv2.ErrNoValidPeers) || errors.Is(resp.Err, errTooManyConsecutiveFailures) { + backoff = true + break outer + } + case <-ctx.Done(): + break outer + } + jc = jobsCh + nextProbe = addrsTracker.GetProbe() + if len(nextProbe) == 0 { + jc = nil + } + } 
+ close(jobsCh) + for addrsTracker.InProgressProbes() > 0 { + resp := <-resultsCh + addrsTracker.CompleteProbe(resp.Requests, resp.Result, resp.Err) + if errors.Is(resp.Err, autonatv2.ErrNoValidPeers) || errors.Is(resp.Err, errTooManyConsecutiveFailures) { + backoff = true + } + } + wg.Wait() + return backoff +} + +// addrsProbeTracker tracks reachability for a set of addresses. This struct decides the priority order of +// addresses for testing reachability. +// +// To execute the probes with a client use the `runProbes` function. +// +// Probes returned by `GetProbe` should be marked as in progress using `MarkProbeInProgress` +// before being executed. +type addrsProbeTracker struct { + now func() time.Time + recentProbeResultWindow int + + mx sync.Mutex + inProgressProbes map[string]int // addr -> count + inProgressProbesTotal int + statuses map[string]*addrStatus + addrs []ma.Multiaddr +} + +func newAddrsTracker(now func() time.Time, recentProbeResultWindow int) *addrsProbeTracker { + return &addrsProbeTracker{ + statuses: make(map[string]*addrStatus), + inProgressProbes: make(map[string]int), + now: now, + recentProbeResultWindow: recentProbeResultWindow, + } +} + +// AppendConfirmedAddrs appends the current confirmed reachable and unreachable addresses. 
+func (t *addrsProbeTracker) AppendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + t.mx.Lock() + defer t.mx.Unlock() + + t.gc() + for _, as := range t.statuses { + switch as.Reachability() { + case network.ReachabilityPublic: + reachable = append(reachable, as.Addr) + case network.ReachabilityPrivate: + unreachable = append(unreachable, as.Addr) + } + } + return reachable, unreachable +} + +func (t *addrsProbeTracker) UpdateAddrs(addrs []ma.Multiaddr) { + t.mx.Lock() + defer t.mx.Unlock() + for _, addr := range addrs { + if _, ok := t.statuses[string(addr.Bytes())]; !ok { + t.statuses[string(addr.Bytes())] = &addrStatus{Addr: addr} + } + } + for k, s := range t.statuses { + found := false + for _, a := range addrs { + if a.Equal(s.Addr) { + found = true + break + } + } + if !found { + delete(t.statuses, k) + } + } + t.addrs = addrs +} + +func (t *addrsProbeTracker) GetProbe() []autonatv2.Request { + t.mx.Lock() + defer t.mx.Unlock() + + reqs := make([]autonatv2.Request, 0, maxAddrsPerRequest) + now := t.now() + for _, a := range t.addrs { + akey := string(a.Bytes()) + pc := t.statuses[akey].ProbeCount(now) + if pc == 0 { + continue + } + if len(reqs) == 0 && t.inProgressProbes[akey] >= pc { + continue + } + reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true}) + if len(reqs) >= maxAddrsPerRequest { + break + } + } + return reqs +} + +// MarkProbeInProgress should be called when a probe is started. +func (t *addrsProbeTracker) MarkProbeInProgress(reqs []autonatv2.Request) { + if len(reqs) == 0 { + return + } + t.mx.Lock() + defer t.mx.Unlock() + t.inProgressProbes[string(reqs[0].Addr.Bytes())]++ + t.inProgressProbesTotal++ +} + +// InProgressProbes returns the number of probes that are currently in progress. +func (t *addrsProbeTracker) InProgressProbes() int { + t.mx.Lock() + defer t.mx.Unlock() + return t.inProgressProbesTotal +} + +// CompleteProbe should be called when a probe completes. 
+func (t *addrsProbeTracker) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Result, err error) { + now := t.now() + + if len(reqs) == 0 { + // should never happen + return + } + + t.mx.Lock() + defer t.mx.Unlock() + + // decrement in-progress count for the first address + primaryAddrKey := string(reqs[0].Addr.Bytes()) + t.inProgressProbes[primaryAddrKey]-- + t.inProgressProbesTotal-- + if t.inProgressProbes[primaryAddrKey] <= 0 { + delete(t.inProgressProbes, primaryAddrKey) + } + + // request failed + if err != nil { + // request refused + if errors.Is(err, autonatv2.ErrDialRefused) { + for _, req := range reqs { + if status, ok := t.statuses[string(req.Addr.Bytes())]; ok { + status.AddRefusal(now) + } + } + } + return + } + + // mark addresses that were skipped as refused + for _, req := range reqs { + if req.Addr.Equal(res.Addr) { + break + } + if status, ok := t.statuses[string(req.Addr.Bytes())]; ok { + status.AddRefusal(now) + } + } + + // record the result for the probed address + if status, ok := t.statuses[string(res.Addr.Bytes())]; ok { + switch res.Status { + case pb.DialStatus_OK: + status.AddResult(now, true) + case pb.DialStatus_E_DIAL_ERROR: + status.AddResult(now, false) + default: + log.Debug("unexpected dial status", res.Addr, res.Status) + } + status.Trim(t.recentProbeResultWindow) + } +} + +func (t *addrsProbeTracker) gc() { + expireBefore := t.now().Add(-maxProbeResultTTL) + for _, s := range t.statuses { + s.ExpireBefore(expireBefore) + } +} + +type probeResult struct { + Time time.Time + Success bool +} + +const ( + // maxProbeResultTTL is the maximum time to keep probe results for an address + maxProbeResultTTL = 3 * time.Hour + // maxProbeInterval is the maximum interval between probes for an address + maxProbeInterval = 1 * time.Hour + // addrRefusedProbeInterval is the interval to probe addresses that have been refused + // these are generally addresses with newer transports for which we don't have many peers + // capable of dialing 
the transport + addrRefusedProbeInterval = 10 * time.Minute + // maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which + // we wait for `addrRefusedProbeInterval` before probing again + maxConsecutiveRefusals = 5 + // confidence is the absolute difference between the number of successes and failures for an address + // targetConfidence is the confidence threshold for an address after which we wait for `maxProbeInterval` + // before probing again + targetConfidence = 3 + // minConfidence is the confidence threshold for an address to be considered reachable or unreachable + // confidence is the absolute difference between the number of successes and failures for an address + minConfidence = 2 + // maxRecentProbeResultWindow is the maximum number of recent probe results to consider for a single address + // + // +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure + // and then a success(...S S S S F S). The confidence in the targetConfidence window will be equal to + // targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval. + maxRecentProbeResultWindow = targetConfidence + 2 +) + +type addrStatus struct { + Addr ma.Multiaddr + results []probeResult + consecutiveRefusals struct { + Count int + Last time.Time + } +} + +func (s *addrStatus) Reachability() network.Reachability { + successes, failures := s.resultCounts() + return s.reachability(successes, failures) +} + +func (*addrStatus) reachability(success, failures int) network.Reachability { + if success-failures >= minConfidence { + return network.ReachabilityPublic + } + if failures-success >= minConfidence { + return network.ReachabilityPrivate + } + return network.ReachabilityUnknown +} + +func (s *addrStatus) ProbeCount(now time.Time) int { + // if we have had too many consecutive refusals, probe after a small wait. 
+ if s.consecutiveRefusals.Count >= maxConsecutiveRefusals { + if s.consecutiveRefusals.Last.Add(addrRefusedProbeInterval).Before(now) { + return 1 + } + return 0 + } + + successes, failures := s.resultCounts() + cnt := 0 + if successes >= failures { + cnt = targetConfidence - (successes - failures) + } + if failures >= successes { + cnt = targetConfidence - (failures - successes) + } + if cnt <= 0 { + if len(s.results) == 0 { + return 0 + } + if s.results[len(s.results)-1].Time.Add(maxProbeInterval).Before(now) { + return 1 + } + // Last probe result was different from reachability. Probe again. + switch s.reachability(successes, failures) { + case network.ReachabilityPublic: + if !s.results[len(s.results)-1].Success { + return 1 + } + case network.ReachabilityPrivate: + if s.results[len(s.results)-1].Success { + return 1 + } + } + return 0 + } + return cnt +} + +func (s *addrStatus) resultCounts() (successes, failures int) { + for _, r := range s.results { + if r.Success { + successes++ + } else { + failures++ + } + } + return successes, failures +} + +func (s *addrStatus) ExpireBefore(before time.Time) { + s.results = slices.DeleteFunc(s.results, func(pr probeResult) bool { + return pr.Time.Before(before) + }) +} + +func (s *addrStatus) AddResult(at time.Time, success bool) { + s.results = append(s.results, probeResult{ + Success: success, + Time: at, + }) + s.consecutiveRefusals.Count = 0 + s.consecutiveRefusals.Last = time.Time{} +} + +func (s *addrStatus) Trim(n int) { + if len(s.results) >= n { + s.results = s.results[len(s.results)-n:] + } +} + +func (s *addrStatus) AddRefusal(at time.Time) { + s.consecutiveRefusals.Count++ + s.consecutiveRefusals.Last = at +} diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go new file mode 100644 index 0000000000..4cfc179e9a --- /dev/null +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -0,0 +1,597 @@ +package basichost + +import ( + "context" + 
"errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddrTrackerGetProbe(t *testing.T) { + pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1") + + cl := clock.NewMock() + + t.Run("inprogress probes", func(t *testing.T) { + tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + + tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) + reqs1 := tr.GetProbe() + reqs2 := tr.GetProbe() + require.Equal(t, reqs1, reqs2) + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.NotEmpty(t, reqs) + tr.MarkProbeInProgress(reqs) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + } + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.NotEmpty(t, reqs) + tr.MarkProbeInProgress(reqs) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) + } + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.Empty(t, reqs) + } + }) + + t.Run("probe refusals", func(t *testing.T) { + tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) + var probes [][]autonatv2.Request + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + tr.MarkProbeInProgress(reqs) + probes = append(probes, reqs) + } + // first one rejected second one successful + for i := 0; i < len(probes); i++ { + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Status: pb.DialStatus_OK}, nil) + } + // the second address is validated! 
+ probes = nil + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) + tr.MarkProbeInProgress(reqs) + probes = append(probes, reqs) + } + reqs := tr.GetProbe() + require.Empty(t, reqs) + for i := 0; i < len(probes); i++ { + tr.CompleteProbe(probes[i], autonatv2.Result{}, autonatv2.ErrDialRefused) + } + // all requests refused + reqs = tr.GetProbe() + require.Empty(t, reqs) + + cl.Add(10*time.Minute + 5*time.Second) + reqs = tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) + }) + + t.Run("probe successes", func(t *testing.T) { + tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) + var probes [][]autonatv2.Request + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + tr.MarkProbeInProgress(reqs) + probes = append(probes, reqs) + } + // first one rejected second one successful + for i := 0; i < len(probes); i++ { + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, Status: pb.DialStatus_E_DIAL_ERROR}, nil) + } + // the second address is validated! 
+ probes = nil + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) + tr.MarkProbeInProgress(reqs) + probes = append(probes, reqs) + } + reqs := tr.GetProbe() + require.Empty(t, reqs) + for i := 0; i < len(probes); i++ { + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Status: pb.DialStatus_OK}, nil) + } + // all statueses probed + reqs = tr.GetProbe() + require.Empty(t, reqs) + + cl.Add(1*time.Hour + 5*time.Second) + reqs = tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + tr.MarkProbeInProgress(reqs) + reqs = tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) + }) + + t.Run("reachabilityUpdate", func(t *testing.T) { + tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) + var probes [][]autonatv2.Request + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + tr.MarkProbeInProgress(reqs) + probes = append(probes, reqs) + } + for i := 0; i < len(probes); i++ { + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil) + } + probes = nil + for i := 0; i < 3; i++ { + reqs := tr.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) + tr.MarkProbeInProgress(reqs) + probes = append(probes, reqs) + } + for i := 0; i < len(probes); i++ { + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Status: pb.DialStatus_E_DIAL_ERROR}, nil) + } + + reachable, unreachable := tr.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Equal(t, unreachable, []ma.Multiaddr{pub2}) + + // should expire addrs after 3 hours + cl.Add(3*time.Hour + 1*time.Second) + reachable, unreachable = 
tr.AppendConfirmedAddrs(nil, nil) + require.Empty(t, reachable) + require.Empty(t, unreachable) + }) +} + +func TestAddrStatus(t *testing.T) { + now := time.Now() + probeResultWindow := maxRecentProbeResultWindow + + type input struct { + At time.Time + Success, Refused bool + } + type testCase struct { + inputs []input + probeCount int + reachability network.Reachability + } + tests := []testCase{ + { + inputs: []input{ + {At: now, Success: true}, + }, + probeCount: 2, + reachability: network.ReachabilityUnknown, + }, + { + inputs: []input{ + {At: now, Success: false}, + {At: now, Success: true}, + {At: now, Success: true}, + {At: now, Success: true}, + }, + probeCount: 1, + reachability: network.ReachabilityPublic, + }, + { + inputs: []input{ + {At: now, Success: true}, + {At: now, Success: false}, + {At: now, Success: false}, + {At: now, Success: false}, + }, + probeCount: 1, + reachability: network.ReachabilityPrivate, + }, + { + inputs: []input{ + {At: now, Success: false}, + {At: now, Success: false}, + {At: now, Success: false}, + {At: now, Success: false}, + {At: now, Success: false}, + {At: now, Success: false}, + {At: now, Success: true}, + {At: now, Success: true}, + {At: now, Success: true}, + {At: now, Success: true}, + }, + probeCount: 0, + reachability: network.ReachabilityPublic, + }, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + s := &addrStatus{Addr: ma.StringCast("/ip4/1.1.1.1/tcp/1")} + for _, inp := range tt.inputs { + if inp.Refused { + s.AddRefusal(now) + } else { + s.AddResult(now, inp.Success) + } + s.Trim(probeResultWindow) + } + require.Equal(t, tt.reachability, s.Reachability()) + require.Equal(t, tt.probeCount, s.ProbeCount(now)) + }) + } +} + +func TestAddrStatusRefused(t *testing.T) { + s := &addrStatus{Addr: ma.StringCast("/ip4/1.1.1.1/tcp/1")} + now := time.Now() + for i := 0; i < maxConsecutiveRefusals-1; i++ { + s.AddRefusal(now) + } + require.Equal(t, s.ProbeCount(now), 3) + 
s.AddRefusal(now) + require.Equal(t, s.ProbeCount(now), 0) + require.Equal(t, s.ProbeCount(now.Add(addrRefusedProbeInterval+(1*time.Nanosecond))), 1) // +1 to push it over the threshold + + s.AddResult(now, true) + require.Equal(t, s.ProbeCount(now), 2) + require.Equal(t, s.consecutiveRefusals.Count, 0) +} + +type mockAutoNATClient struct { + F func(context.Context, []autonatv2.Request) (autonatv2.Result, error) +} + +func (m mockAutoNATClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + return m.F(ctx, reqs) +} + +var _ autonatv2Client = mockAutoNATClient{} + +func TestAddrReachabilityTracker(t *testing.T) { + pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") + pub2, _ := ma.NewMultiaddr("/ip4/1.1.1.2/tcp/1") + pub3, _ := ma.NewMultiaddr("/ip4/1.1.1.3/tcp/1") + pri, _ := ma.NewMultiaddr("/ip4/192.168.1.1/tcp/1") + + newTracker := func(cli mockAutoNATClient, cl clock.Clock) *addrsReachabilityTracker { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if cl == nil { + cl = clock.New() + } + tr := &addrsReachabilityTracker{ + ctx: ctx, + cancel: cancel, + cli: cli, + newAddrs: make(chan []ma.Multiaddr, 1), + reachabilityUpdateCh: make(chan struct{}, 1), + maxConcurrency: 3, + newAddrsProbeDelay: 0 * time.Second, + addrTracker: newAddrsTracker(cl.Now, maxRecentProbeResultWindow), + clock: cl, + } + err := tr.Start() + require.NoError(t, err) + t.Cleanup(func() { + err := tr.Close() + assert.NoError(t, err) + }) + return tr + } + + t.Run("simple", func(t *testing.T) { + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + for _, req := range reqs { + if req.Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + } else if req.Addr.Equal(pub2) { + return autonatv2.Result{Addr: pub2, Status: pb.DialStatus_E_DIAL_ERROR}, nil + } + } + return autonatv2.Result{}, autonatv2.ErrDialRefused + }, + } + tr 
:= newTracker(mockClient, nil) + tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pri}) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(2 * time.Second): + t.Fatal("expected reachability update") + } + reachable, unreachable := tr.ConfirmedAddrs() + require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1) + require.Empty(t, unreachable) + + tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pub2, pri}) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(1 * time.Second): + t.Fatal("unexpected call") + } + reachable, unreachable = tr.ConfirmedAddrs() + require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1) + require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2) + }) + + t.Run("backoff", func(t *testing.T) { + notify := make(chan struct{}, 1) + drainNotify := func() { + for { + select { + case <-notify: + default: + return + } + } + } + + var allow atomic.Bool + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + select { + case notify <- struct{}{}: + default: + } + if !allow.Load() { + return autonatv2.Result{}, autonatv2.ErrNoValidPeers + } + if reqs[0].Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + } + return autonatv2.Result{}, autonatv2.ErrDialRefused + }, + } + + cl := clock.NewMock() + tr := newTracker(mockClient, cl) + + // update addrs and wait for initial checks + tr.UpdateAddrs([]ma.Multiaddr{pub1}) + // need to update clock after the background goroutine processes the new addrs + time.Sleep(100 * time.Millisecond) + cl.Add(1) + select { + case <-tr.reachabilityUpdateCh: + reachable, unreachable := tr.ConfirmedAddrs() + require.Empty(t, reachable) + require.Empty(t, unreachable) + case <-time.After(1 * time.Second): + t.Fatal("unexpected call") + } + + backoffInterval := backoffStartInterval + for i := 0; i < 4; i++ { + drainNotify() + cl.Add(backoffInterval / 2) + 
select { + case <-notify: + t.Fatal("unexpected call") + case <-time.After(50 * time.Millisecond): + } + cl.Add(backoffInterval/2 + 1) // +1 to push it slightly over the backoff interval + backoffInterval *= 2 + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(1 * time.Second): + t.Fatal("unexpected call") + } + reachable, unreachable := tr.ConfirmedAddrs() + require.Empty(t, reachable) + require.Empty(t, unreachable) + } + allow.Store(true) + drainNotify() + cl.Add(backoffInterval + 1) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(1 * time.Second): + t.Fatal("unexpected call") + } + reachable, unreachable := tr.ConfirmedAddrs() + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) + }) + + t.Run("event update", func(t *testing.T) { + // allow minConfidence probes to pass + called := make(chan struct{}, minConfidence) + notify := make(chan struct{}) + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + select { + case <-ctx.Done(): + return autonatv2.Result{}, ctx.Err() + case called <- struct{}{}: + notify <- struct{}{} + } + return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + }, + } + + tr := newTracker(mockClient, clock.New()) + tr.UpdateAddrs([]ma.Multiaddr{pub1}) + for i := 0; i < minConfidence; i++ { + select { + case <-notify: + case <-time.After(1 * time.Second): + t.Fatal("expected call to autonat client") + } + } + select { + case <-tr.reachabilityUpdateCh: + t.Fatal("didn't expect reachability update") + case <-time.After(1 * time.Second): + } + tr.UpdateAddrs([]ma.Multiaddr{pub1}) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(1 * time.Second): + t.Fatal("expected reachability update") + } + }) +} + +func TestRunProbes(t *testing.T) { + pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + pub2 := ma.StringCast("/ip4/1.1.1.1/tcp/2") + ctx, cancel := context.WithTimeout(context.Background(), 
10*time.Second) + defer cancel() + t.Run("backoff on ErrNoValidPeers", func(t *testing.T) { + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + return autonatv2.Result{}, autonatv2.ErrNoValidPeers + }, + } + + addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) + result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) + require.True(t, result) + require.Equal(t, addrTracker.InProgressProbes(), 0) + }) + + t.Run("returns backoff on errTooManyConsecutiveFailures", func(t *testing.T) { + // Create a client that always returns ErrDialRefused + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + return autonatv2.Result{}, errors.New("test error") + }, + } + + addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) + + result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) + require.True(t, result) + require.Equal(t, addrTracker.InProgressProbes(), 0) + }) + + t.Run("quits on cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + block := make(chan struct{}) + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + block <- struct{}{} + return autonatv2.Result{}, nil + }, + } + + addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) + assert.False(t, result) + assert.Equal(t, addrTracker.InProgressProbes(), 0) + }() + + cancel() + time.Sleep(50 * time.Millisecond) // wait for the cancellation to be processed + + outer: + for i := 0; i < 
defaultMaxConcurrency; i++ { + select { + case <-block: + default: + break outer + } + } + select { + case <-block: + t.Fatal("expected no more requests") + case <-time.After(50 * time.Millisecond): + } + wg.Wait() + }) + + t.Run("handles refusals", func(t *testing.T) { + pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") + + addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) + + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + for _, req := range reqs { + if req.Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + } + } + return autonatv2.Result{}, autonatv2.ErrDialRefused + }, + } + + result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) + require.False(t, result) + + reachable, unreachable := addrTracker.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) + require.Equal(t, addrTracker.InProgressProbes(), 0) + }) + + t.Run("handles completions", func(t *testing.T) { + addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) + + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + for _, req := range reqs { + if req.Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + } + if req.Addr.Equal(pub2) { + return autonatv2.Result{Addr: pub2, Status: pb.DialStatus_E_DIAL_ERROR}, nil + } + } + return autonatv2.Result{}, autonatv2.ErrDialRefused + }, + } + + result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) + require.False(t, result) + + reachable, unreachable := addrTracker.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Equal(t, unreachable, []ma.Multiaddr{pub2}) + require.Equal(t, 
addrTracker.InProgressProbes(), 0) + }) +} + +func BenchmarkAddrTracker(b *testing.B) { + cl := clock.NewMock() + t := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + + var addrs []ma.Multiaddr + for i := 0; i < 20; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", i))) + } + t.UpdateAddrs(addrs) + b.ReportAllocs() + b.ResetTimer() + p := t.GetProbe() + for i := 0; i < b.N; i++ { + pp := t.GetProbe() + if len(pp) == 0 { + pp = p + } + t.MarkProbeInProgress(pp) + t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Status: pb.DialStatus_OK}, nil) + } +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index b4db8c2091..a42ad9b2eb 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -236,7 +236,25 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { }); ok { tfl = s.TransportForListening } - h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan) + + if opts.EnableAutoNATv2 { + var mt autonatv2.MetricsTracer + if opts.EnableMetrics { + mt = autonatv2.NewMetricsTracer(opts.PrometheusRegisterer) + } + // keep this on host as it has the server as well as the client + h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer, autonatv2.WithMetricsTracer(mt)) + if err != nil { + return nil, fmt.Errorf("failed to create autonatv2: %w", err) + } + } + + // avoid typed nil errors + var autonatv2Client autonatv2Client + if h.autonatv2 != nil { + autonatv2Client = h.autonatv2 + } + h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan, autonatv2Client) if err != nil { return nil, fmt.Errorf("failed to create address service: %w", err) } @@ -283,17 +301,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h.pings = ping.NewPingService(h) } - if opts.EnableAutoNATv2 { - var mt 
autonatv2.MetricsTracer - if opts.EnableMetrics { - mt = autonatv2.NewMetricsTracer(opts.PrometheusRegisterer) - } - h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer, autonatv2.WithMetricsTracer(mt)) - if err != nil { - return nil, fmt.Errorf("failed to create autonatv2: %w", err) - } - } - if !h.disableSignedPeerRecord { h.signKey = h.Peerstore().PrivKey(h.ID()) cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) @@ -754,6 +761,16 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { return h.addressManager.DirectAddrs() } +// ReachableAddrs returns all addresses of the host that are reachable from the internet +// as verified by autonatv2. +// +// Experimental: This API may change in the future without deprecation. +// +// Requires AutoNATv2 to be enabled. +func (h *BasicHost) ReachableAddrs() []ma.Multiaddr { + return h.addressManager.ReachableAddrs() +} + func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr { totalSize := 0 for _, a := range addrs { @@ -836,7 +853,6 @@ func (h *BasicHost) Close() error { if h.cmgr != nil { h.cmgr.Close() } - h.addressManager.Close() if h.ids != nil { diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 80a1d3dd62..757d43f27b 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -47,6 +47,7 @@ func TestHostSimple(t *testing.T) { h1.Start() h2, err := NewHost(swarmt.GenSwarm(t), nil) require.NoError(t, err) + defer h2.Close() h2.Start() @@ -211,6 +212,7 @@ func TestAllAddrs(t *testing.T) { // no listen addrs h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil) require.NoError(t, err) + h.Start() defer h.Close() require.Nil(t, h.AllAddrs()) diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index ee4318dd87..717bb65ef4 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -8,8 +8,6 @@ import ( "sync" "time" - "math/rand/v2" - logging 
"github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" @@ -35,6 +33,8 @@ const ( // maxPeerAddresses is the number of addresses in a dial request the server // will inspect, rest are ignored. maxPeerAddresses = 50 + + defaultThrottlePeerDuration = 2 * time.Minute ) var ( @@ -76,8 +76,10 @@ type AutoNAT struct { srv *server cli *client - mx sync.Mutex - peers *peersMap + mx sync.Mutex + peers map[peer.ID]struct{} + throttlePeer map[peer.ID]time.Time + throttlePeerDuration time.Duration // allowPrivateAddrs enables using private and localhost addresses for reachability checks. // This is only useful for testing. allowPrivateAddrs bool @@ -96,18 +98,21 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, ctx, cancel := context.WithCancel(context.Background()) an := &AutoNAT{ - host: host, - ctx: ctx, - cancel: cancel, - srv: newServer(host, dialerHost, s), - cli: newClient(host), - allowPrivateAddrs: s.allowPrivateAddrs, - peers: newPeersMap(), + host: host, + ctx: ctx, + cancel: cancel, + srv: newServer(host, dialerHost, s), + cli: newClient(host), + allowPrivateAddrs: s.allowPrivateAddrs, + peers: make(map[peer.ID]struct{}), + throttlePeer: make(map[peer.ID]time.Time), + throttlePeerDuration: s.throttlePeerDuration, } return an, nil } func (an *AutoNAT) background(sub event.Subscription) { + ticker := time.NewTicker(10 * time.Minute) for { select { case <-an.ctx.Done(): @@ -123,6 +128,15 @@ func (an *AutoNAT) background(sub event.Subscription) { case event.EvtPeerIdentificationCompleted: an.updatePeer(evt.Peer) } + case <-ticker.C: + now := time.Now() + an.mx.Lock() + for p, t := range an.throttlePeer { + if t.Before(now) { + delete(an.throttlePeer, p) + } + } + an.mx.Unlock() } } } @@ -163,17 +177,25 @@ func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, } } } + now := time.Now() an.mx.Lock() - p := an.peers.GetRand() + var p peer.ID + for pr := range 
an.peers { + if t := an.throttlePeer[pr]; t.After(now) { + continue + } + p = pr + an.throttlePeer[p] = time.Now().Add(an.throttlePeerDuration) + break + } an.mx.Unlock() if p == "" { return Result{}, ErrNoValidPeers } - res, err := an.cli.GetReachability(ctx, p, reqs) if err != nil { log.Debugf("reachability check with %s failed, err: %s", p, err) - return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err) + return res, fmt.Errorf("reachability check with %s failed: %w", p, err) } log.Debugf("reachability check with %s successful", p) return res, nil @@ -187,49 +209,9 @@ func (an *AutoNAT) updatePeer(p peer.ID) { // and swarm for the current state protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol) connectedness := an.host.Network().Connectedness(p) - if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected { - an.peers.Put(p) + if err == nil && connectedness == network.Connected && slices.Contains(protos, DialProtocol) { + an.peers[p] = struct{}{} } else { - an.peers.Delete(p) - } -} - -// peersMap provides random access to a set of peers. This is useful when the map iteration order is -// not sufficiently random. 
-type peersMap struct { - peerIdx map[peer.ID]int - peers []peer.ID -} - -func newPeersMap() *peersMap { - return &peersMap{ - peerIdx: make(map[peer.ID]int), - peers: make([]peer.ID, 0), - } -} - -func (p *peersMap) GetRand() peer.ID { - if len(p.peers) == 0 { - return "" - } - return p.peers[rand.IntN(len(p.peers))] -} - -func (p *peersMap) Put(pid peer.ID) { - if _, ok := p.peerIdx[pid]; ok { - return - } - p.peers = append(p.peers, pid) - p.peerIdx[pid] = len(p.peers) - 1 -} - -func (p *peersMap) Delete(pid peer.ID) { - idx, ok := p.peerIdx[pid] - if !ok { - return + delete(an.peers, p) } - p.peers[idx] = p.peers[len(p.peers)-1] - p.peerIdx[p.peers[idx]] = idx - p.peers = p.peers[:len(p.peers)-1] - delete(p.peerIdx, pid) } diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 11c8f02195..c2ecaba62e 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -36,6 +36,7 @@ func newAutoNAT(t testing.TB, dialer host.Host, opts ...AutoNATOption) *AutoNAT swarm.WithUDPBlackHoleSuccessCounter(nil), swarm.WithIPv6BlackHoleSuccessCounter(nil)))) } + opts = append([]AutoNATOption{withThrottlePeerDuration(0)}, opts...) an, err := New(h, dialer, opts...) 
if err != nil { t.Error(err) @@ -74,7 +75,7 @@ func waitForPeer(t testing.TB, a *AutoNAT) { require.Eventually(t, func() bool { a.mx.Lock() defer a.mx.Unlock() - return a.peers.GetRand() != "" + return len(a.peers) != 0 }, 5*time.Second, 100*time.Millisecond) } @@ -526,71 +527,31 @@ func TestEventSubscription(t *testing.T) { require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers.peers) == 1 + return len(an.peers) == 1 }, 5*time.Second, 100*time.Millisecond) idAndConnect(t, an.host, c) require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers.peers) == 2 + return len(an.peers) == 2 }, 5*time.Second, 100*time.Millisecond) an.host.Network().ClosePeer(b.ID()) require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers.peers) == 1 + return len(an.peers) == 1 }, 5*time.Second, 100*time.Millisecond) an.host.Network().ClosePeer(c.ID()) require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers.peers) == 0 + return len(an.peers) == 0 }, 5*time.Second, 100*time.Millisecond) } -func TestPeersMap(t *testing.T) { - emptyPeerID := peer.ID("") - - t.Run("single_item", func(t *testing.T) { - p := newPeersMap() - p.Put("peer1") - p.Delete("peer1") - p.Put("peer1") - require.Equal(t, peer.ID("peer1"), p.GetRand()) - p.Delete("peer1") - require.Equal(t, emptyPeerID, p.GetRand()) - }) - - t.Run("multiple_items", func(t *testing.T) { - p := newPeersMap() - require.Equal(t, emptyPeerID, p.GetRand()) - - allPeers := make(map[peer.ID]bool) - for i := 0; i < 20; i++ { - pid := peer.ID(fmt.Sprintf("peer-%d", i)) - allPeers[pid] = true - p.Put(pid) - } - foundPeers := make(map[peer.ID]bool) - for i := 0; i < 1000; i++ { - pid := p.GetRand() - require.NotEqual(t, emptyPeerID, p) - require.True(t, allPeers[pid]) - foundPeers[pid] = true - if len(foundPeers) == len(allPeers) { - break - } - } - for pid := range allPeers { - p.Delete(pid) - } - require.Equal(t, 
emptyPeerID, p.GetRand()) - }) -} - func TestAreAddrsConsistency(t *testing.T) { c := &client{ normalizeMultiaddr: func(a ma.Multiaddr) ma.Multiaddr { diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index bbb6145b8c..d71e68d73d 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -111,7 +111,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request case msg.GetDialDataRequest() != nil: if err := ac.validateDialDataRequest(reqs, &msg); err != nil { s.Reset() - return Result{}, fmt.Errorf("invalid dial data request: %w", err) + return Result{}, fmt.Errorf("invalid dial data request: %s %w", s.Conn().RemoteMultiaddr(), err) } // dial data request is valid and we want to send data if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil { @@ -147,7 +147,6 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request if int(resp.AddrIdx) >= len(reqs) { return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs)) } - // wait for nonce from the server var dialBackAddr ma.Multiaddr if resp.GetDialStatus() == pb.DialStatus_OK { diff --git a/p2p/protocol/autonatv2/options.go b/p2p/protocol/autonatv2/options.go index 76cd3735f4..f7cf4b7178 100644 --- a/p2p/protocol/autonatv2/options.go +++ b/p2p/protocol/autonatv2/options.go @@ -13,6 +13,7 @@ type autoNATSettings struct { now func() time.Time amplificatonAttackPreventionDialWait time.Duration metricsTracer MetricsTracer + throttlePeerDuration time.Duration } func defaultSettings() *autoNATSettings { @@ -25,6 +26,7 @@ func defaultSettings() *autoNATSettings { dataRequestPolicy: amplificationAttackPrevention, amplificatonAttackPreventionDialWait: 3 * time.Second, now: time.Now, + throttlePeerDuration: defaultThrottlePeerDuration, } } @@ -65,3 +67,10 @@ func withAmplificationAttackPreventionDialWait(d time.Duration) 
AutoNATOption { return nil } } + +func withThrottlePeerDuration(d time.Duration) AutoNATOption { + return func(s *autoNATSettings) error { + s.throttlePeerDuration = d + return nil + } +} diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index 05e0bdd9fd..92bf42cef4 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -522,6 +522,6 @@ func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool if err != nil { return true } - dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr + dialIP, _ := manet.ToIP(dialAddr) // must be an IP multiaddr return !connIP.Equal(dialIP) } From 5801a052736bf84d5f95bb63c8639541bb467688 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 14 Apr 2025 16:54:24 +0530 Subject: [PATCH 02/15] fix race in addrs update --- p2p/host/basic/addrs_manager.go | 135 +++++++++++++++++++------------- 1 file changed, 80 insertions(+), 55 deletions(-) diff --git a/p2p/host/basic/addrs_manager.go b/p2p/host/basic/addrs_manager.go index 3c09c98bb8..4114a43558 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -2,6 +2,7 @@ package basichost import ( "context" + "errors" "fmt" "net" "slices" @@ -110,11 +111,7 @@ func (a *addrsManager) Start() error { } } - err := a.background() - if err != nil { - return err - } - return nil + return a.startBackgroundWorker() } func (a *addrsManager) Close() { @@ -151,7 +148,7 @@ func (a *addrsManager) triggerAddrsUpdate() { } } -func (a *addrsManager) background() error { +func (a *addrsManager) startBackgroundWorker() error { autoRelayAddrsSub, err := a.bus.Subscribe(new(event.EvtAutoRelayAddrsUpdated), eventbus.Name("addrs-manager")) if err != nil { return fmt.Errorf("error subscribing to auto relay addrs: %s", err) @@ -159,12 +156,26 @@ func (a *addrsManager) background() error { autonatReachabilitySub, err := a.bus.Subscribe(new(event.EvtLocalReachabilityChanged), 
eventbus.Name("addrs-manager"))
 	if err != nil {
-		return fmt.Errorf("error subscribing to autonat reachability: %s", err)
+		err1 := autoRelayAddrsSub.Close()
+		if err1 != nil {
+			err1 = fmt.Errorf("error closing autorelaysub: %w", err1)
+		}
+		err = fmt.Errorf("error subscribing to autonat reachability: %s", err)
+		return errors.Join(err, err1)
 	}
 	emitter, err := a.bus.Emitter(new(event.EvtHostReachableAddrsChanged), eventbus.Stateful)
 	if err != nil {
-		return fmt.Errorf("error creating host reachable addrs emitter: %w", err)
+		err1 := autoRelayAddrsSub.Close()
+		if err1 != nil {
+			err1 = fmt.Errorf("error closing autorelaysub: %w", err1)
+		}
+		err2 := autonatReachabilitySub.Close()
+		if err2 != nil {
+			err2 = fmt.Errorf("error closing autonat reachability: %w", err2)
+		}
+		err = fmt.Errorf("error creating host reachable addrs emitter: %w", err)
+		return errors.Join(err, err1, err2)
 	}
 
 	// update relay addrs in case we're private
@@ -183,66 +194,73 @@ func (a *addrsManager) background() error {
 		}
 	default:
 	}
+	// update addresses before starting the worker loop. This ensures that any address updates
+	// before calling addrsManager.Start are correctly reported after Start returns.
a.updateAddrs() - a.wg.Add(1) go func() { defer a.wg.Done() - defer func() { - err := autoRelayAddrsSub.Close() - if err != nil { - log.Warnf("error closing auto relay addrs sub: %s", err) - } - }() - defer func() { - err := autonatReachabilitySub.Close() - if err != nil { - log.Warnf("error closing autonat reachability sub: %s", err) + a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter) + }() + return nil +} + +func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription, emitter event.Emitter) { + defer func() { + err := autoRelayAddrsSub.Close() + if err != nil { + log.Warnf("error closing auto relay addrs sub: %s", err) + } + }() + defer func() { + err := autonatReachabilitySub.Close() + if err != nil { + log.Warnf("error closing autonat reachability sub: %s", err) + } + }() + + ticker := time.NewTicker(addrChangeTickrInterval) + defer ticker.Stop() + var previousAddrs hostAddrs + for { + currAddrs := a.updateAddrs() + a.notifyAddrsChanged(emitter, previousAddrs, currAddrs) + previousAddrs = currAddrs + select { + case <-ticker.C: + case <-a.triggerAddrsUpdateChan: + case e := <-autoRelayAddrsSub.Out(): + if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { + a.updateRelayAddrs(evt.RelayAddrs) } - }() - - ticker := time.NewTicker(addrChangeTickrInterval) - defer ticker.Stop() - var previousAddrs hostAddrs - for { - currAddrs := a.updateAddrs() - a.notifyAddrsChanged(emitter, previousAddrs, currAddrs) - previousAddrs = currAddrs - select { - case <-ticker.C: - case <-a.triggerAddrsUpdateChan: - case e := <-autoRelayAddrsSub.Out(): - if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { - a.updateRelayAddrs(evt.RelayAddrs) - } - case <-a.triggerReachabilityUpdate: - case e := <-autonatReachabilitySub.Out(): - if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { - a.hostReachability.Store(&evt.Reachability) - } - case <-a.ctx.Done(): - return + case <-a.triggerReachabilityUpdate: + case e := 
<-autonatReachabilitySub.Out(): + if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { + a.hostReachability.Store(&evt.Reachability) } + case <-a.ctx.Done(): + return } - }() - return nil + } } func (a *addrsManager) updateAddrs() hostAddrs { + a.addrsMx.Lock() + defer a.addrsMx.Unlock() + localAddrs := a.getLocalAddrs() var currReachableAddrs, currUnreachableAddrs []ma.Multiaddr if a.addrsReachabilityTracker != nil { currReachableAddrs, currUnreachableAddrs = a.getConfirmedAddrs(localAddrs) } - currAddrs := a.getAddrs(slices.Clone(localAddrs), a.RelayAddrs()) + currAddrs := a.getAddrs(slices.Clone(localAddrs), a.getRelayAddrsUnlocked()) - // maybe we can avoid this clone? - a.addrsMx.Lock() - a.currentAddrs.addrs = append(a.currentAddrs.addrs[:0], currAddrs...) - a.currentAddrs.localAddrs = append(a.currentAddrs.localAddrs[:0], localAddrs...) - a.currentAddrs.reachableAddrs = append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...) - a.currentAddrs.unreachableAddrs = append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...) 
- a.addrsMx.Unlock() + a.currentAddrs = hostAddrs{ + addrs: append(a.currentAddrs.addrs[:0], currAddrs...), + localAddrs: append(a.currentAddrs.localAddrs[:0], localAddrs...), + reachableAddrs: append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...), + unreachableAddrs: append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...), + } return hostAddrs{ localAddrs: localAddrs, @@ -254,12 +272,13 @@ func (a *addrsManager) updateAddrs() hostAddrs { func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) { if areAddrsDifferent(previous.localAddrs, current.localAddrs) { - log.Debugf("host addresses updated: %s", current.localAddrs) + log.Debugf("host local addresses updated: %s", current.localAddrs) if a.addrsReachabilityTracker != nil { a.addrsReachabilityTracker.UpdateAddrs(current.localAddrs) } } if areAddrsDifferent(previous.addrs, current.addrs) { + log.Debugf("host addresses updated: %s", current.localAddrs) select { case a.addrsUpdatedChan <- struct{}{}: default: @@ -268,6 +287,7 @@ func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, curre if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) || areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) { + log.Debugf("host reachable addrs updated: %s", current.localAddrs) if err := emitter.Emit(event.EvtHostReachableAddrsChanged{ Reachable: slices.Clone(current.reachableAddrs), Unreachable: slices.Clone(current.unreachableAddrs), @@ -284,7 +304,7 @@ func (a *addrsManager) Addrs() []ma.Multiaddr { return a.getAddrs(a.DirectAddrs(), a.RelayAddrs()) } -// mutates localAddrs +// getAddrs returns the node's dialable addresses. 
Mutates localAddrs func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr { addrs := localAddrs rch := a.hostReachability.Load() @@ -302,6 +322,7 @@ func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multi slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) return addrs } + func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr { addrs := a.DirectAddrs() addrs = slices.Clone(a.addrsFactory(addrs)) @@ -331,6 +352,10 @@ func (a *addrsManager) ReachableAddrs() []ma.Multiaddr { func (a *addrsManager) RelayAddrs() []ma.Multiaddr { a.addrsMx.RLock() defer a.addrsMx.RUnlock() + return a.getRelayAddrsUnlocked() +} + +func (a *addrsManager) getRelayAddrsUnlocked() []ma.Multiaddr { return slices.Clone(a.relayAddrs) } @@ -342,7 +367,7 @@ func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) { func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { reachableAddrs, unreachableAddrs = a.addrsReachabilityTracker.ConfirmedAddrs() - // Only include relevant host addresses as the reachability manager may have + // Only include host addresses as the reachability manager may have // a stale view of host's addresses. 
reachableAddrs = slices.DeleteFunc(reachableAddrs, func(a ma.Multiaddr) bool { return !contains(localAddrs, a) From 33c34c92521ff9352ce16ed337d07436b0f8acfb Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 14 Apr 2025 17:00:31 +0530 Subject: [PATCH 03/15] autonatv2: randomize peers selected for calls --- p2p/host/basic/addrs_manager_test.go | 4 +- p2p/host/basic/addrs_reachability_tracker.go | 88 +++++++++++-------- .../basic/addrs_reachability_tracker_test.go | 30 +++---- p2p/protocol/autonatv2/autonat.go | 83 ++++++++++++++--- p2p/protocol/autonatv2/autonat_test.go | 24 ++--- p2p/protocol/autonatv2/client.go | 33 ++++--- p2p/protocol/autonatv2/server_test.go | 68 +++++++------- 7 files changed, 196 insertions(+), 134 deletions(-) diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index 9b01149dae..260c3de07d 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -455,9 +455,9 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) { AutoNATClient: mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { if reqs[0].Addr.Equal(publicQUIC) { - return autonatv2.Result{Addr: reqs[0].Addr, Status: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: reqs[0].Addr, DialStatus: pb.DialStatus_OK}, nil } else if reqs[0].Addr.Equal(publicTCP) { - return autonatv2.Result{Addr: reqs[0].Addr, Status: pb.DialStatus_E_DIAL_ERROR}, nil + return autonatv2.Result{Addr: reqs[0].Addr, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil } t.Errorf("received invalid request for addr: %+v", reqs[0]) return autonatv2.Result{}, errors.New("invalid") diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index e0b385dc4b..68cacfe8b1 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -12,7 +12,6 @@ import ( "github.com/benbjohnson/clock" 
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" - "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) @@ -213,30 +212,36 @@ var errTooManyConsecutiveFailures = errors.New("too many consecutive failures") // errTooManyConsecutiveFailures in case of many consecutive failures type errCountingClient struct { autonatv2Client - MaxConsecutiveErrors int - mx sync.Mutex - consecutiveErrors int + MaxConsecutiveErrors int + mx sync.Mutex + consecutiveErrors int + loggedPrivateAddrsError bool } func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { res, err := c.autonatv2Client.GetReachability(ctx, reqs) c.mx.Lock() defer c.mx.Unlock() - if err == nil || errors.Is(err, autonatv2.ErrDialRefused) || errors.Is(err, autonatv2.ErrNoValidPeers) { - c.consecutiveErrors = 0 - } else { + if err != nil { c.consecutiveErrors++ if c.consecutiveErrors > c.MaxConsecutiveErrors { err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err) } + // This is hacky, but we do want to log this error + if !c.loggedPrivateAddrsError && errors.Is(err, autonatv2.ErrPrivateAddrs) { + log.Errorf("private IP addr in autonatv2 request: %s", err) + c.loggedPrivateAddrsError = true // log it only once. 
This should never happen + } + } else { + c.consecutiveErrors = 0 } return res, err } type probeResponse struct { - Requests []autonatv2.Request - Result autonatv2.Result - Err error + Req []autonatv2.Request + Res autonatv2.Result + Err error } const maxConsecutiveErrors = 20 @@ -261,7 +266,7 @@ func runProbes(ctx context.Context, concurrency int, addrsTracker *addrsProbeTra ctx, cancel := context.WithTimeout(ctx, 30*time.Second) res, err := client.GetReachability(ctx, reqs) cancel() - resultsCh <- probeResponse{Requests: reqs, Result: res, Err: err} + resultsCh <- probeResponse{Req: reqs, Res: res, Err: err} } }() } @@ -274,8 +279,8 @@ outer: case jc <- nextProbe: addrsTracker.MarkProbeInProgress(nextProbe) case resp := <-resultsCh: - addrsTracker.CompleteProbe(resp.Requests, resp.Result, resp.Err) - if errors.Is(resp.Err, autonatv2.ErrNoValidPeers) || errors.Is(resp.Err, errTooManyConsecutiveFailures) { + addrsTracker.CompleteProbe(resp.Req, resp.Res, resp.Err) + if isErrorPersistent(resp.Err) { backoff = true break outer } @@ -291,8 +296,8 @@ outer: close(jobsCh) for addrsTracker.InProgressProbes() > 0 { resp := <-resultsCh - addrsTracker.CompleteProbe(resp.Requests, resp.Result, resp.Err) - if errors.Is(resp.Err, autonatv2.ErrNoValidPeers) || errors.Is(resp.Err, errTooManyConsecutiveFailures) { + addrsTracker.CompleteProbe(resp.Req, resp.Res, resp.Err) + if isErrorPersistent(resp.Err) { backoff = true } } @@ -300,10 +305,19 @@ outer: return backoff } +// isErrorPersistent returns whether the error will repeat on future probes for a while +func isErrorPersistent(err error) bool { + if err == nil { + return false + } + return errors.Is(err, autonatv2.ErrPrivateAddrs) || errors.Is(err, autonatv2.ErrNoPeers) || + errors.Is(err, errTooManyConsecutiveFailures) +} + // addrsProbeTracker tracks reachability for a set of addresses. This struct decides the priority order of // addresses for testing reachability. 
// -// To execute the probes with a client use the `runProbes` function. +// Use the `runProbes` function to execute the probes with an autonatv2 client. // // Probes returned by `GetProbe` should be marked as in progress using `MarkProbeInProgress` // before being executed. @@ -344,6 +358,7 @@ func (t *addrsProbeTracker) AppendConfirmedAddrs(reachable, unreachable []ma.Mul return reachable, unreachable } +// UpdateAddrs updates the tracked addrs func (t *addrsProbeTracker) UpdateAddrs(addrs []ma.Multiaddr) { t.mx.Lock() defer t.mx.Unlock() @@ -367,6 +382,9 @@ func (t *addrsProbeTracker) UpdateAddrs(addrs []ma.Multiaddr) { t.addrs = addrs } +// GetProbe returns the next probe. Returns empty slice in case there are no more probes. +// Probes that are run against an autonatv2 client should be marked in progress with +// `MarkProbeInProgress` before running. func (t *addrsProbeTracker) GetProbe() []autonatv2.Request { t.mx.Lock() defer t.mx.Unlock() @@ -391,6 +409,7 @@ func (t *addrsProbeTracker) GetProbe() []autonatv2.Request { } // MarkProbeInProgress should be called when a probe is started. +// All in progress probes *MUST* be completed with `CompleteProbe` func (t *addrsProbeTracker) MarkProbeInProgress(reqs []autonatv2.Request) { if len(reqs) == 0 { return @@ -428,24 +447,23 @@ func (t *addrsProbeTracker) CompleteProbe(reqs []autonatv2.Request, res autonatv delete(t.inProgressProbes, primaryAddrKey) } - // request failed + // nothing to do if the request errored. 
if err != nil { - // request refused - if errors.Is(err, autonatv2.ErrDialRefused) { - for _, req := range reqs { - if status, ok := t.statuses[string(req.Addr.Bytes())]; ok { - status.AddRefusal(now) - } + return + } + + // request failed + if res.AllAddrsRefused { + for _, req := range reqs { + if status, ok := t.statuses[string(req.Addr.Bytes())]; ok { + status.AddRefusal(now) } } return } // mark addresses that were skipped as refused - for _, req := range reqs { - if req.Addr.Equal(res.Addr) { - break - } + for _, req := range reqs[:res.Idx] { if status, ok := t.statuses[string(req.Addr.Bytes())]; ok { status.AddRefusal(now) } @@ -453,13 +471,13 @@ func (t *addrsProbeTracker) CompleteProbe(reqs []autonatv2.Request, res autonatv // record the result for the probed address if status, ok := t.statuses[string(res.Addr.Bytes())]; ok { - switch res.Status { - case pb.DialStatus_OK: + switch res.Reachability { + case network.ReachabilityPublic: status.AddResult(now, true) - case pb.DialStatus_E_DIAL_ERROR: + case network.ReachabilityPrivate: status.AddResult(now, false) default: - log.Debug("unexpected dial status", res.Addr, res.Status) + log.Debug("unexpected dial status", res.Addr) } status.Trim(t.recentProbeResultWindow) } @@ -579,9 +597,9 @@ func (s *addrStatus) resultCounts() (successes, failures int) { return successes, failures } -func (s *addrStatus) ExpireBefore(before time.Time) { +func (s *addrStatus) ExpireBefore(expiry time.Time) { s.results = slices.DeleteFunc(s.results, func(pr probeResult) bool { - return pr.Time.Before(before) + return pr.Time.Before(expiry) }) } @@ -595,8 +613,8 @@ func (s *addrStatus) AddResult(at time.Time, success bool) { } func (s *addrStatus) Trim(n int) { - if len(s.results) >= n { - s.results = s.results[len(s.results)-n:] + if len(s.results) > n { + s.results = slices.Delete(s.results, 0, len(s.results)-n) } } diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go 
index 4cfc179e9a..c80909344d 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -61,7 +61,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { } // first one rejected second one successful for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Status: pb.DialStatus_OK}, nil) + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_OK}, nil) } // the second address is validated! probes = nil @@ -97,7 +97,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { } // first one rejected second one successful for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, Status: pb.DialStatus_E_DIAL_ERROR}, nil) + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil) } // the second address is validated! probes = nil @@ -110,7 +110,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { reqs := tr.GetProbe() require.Empty(t, reqs) for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Status: pb.DialStatus_OK}, nil) + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_OK}, nil) } // all statueses probed reqs = tr.GetProbe() @@ -135,7 +135,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { probes = append(probes, reqs) } for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil) + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil) } probes = nil for i := 0; i < 3; i++ { @@ -145,7 +145,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { probes = append(probes, reqs) } for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Status: pb.DialStatus_E_DIAL_ERROR}, nil) + tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_E_DIAL_ERROR}, 
nil) } reachable, unreachable := tr.AppendConfirmedAddrs(nil, nil) @@ -297,9 +297,9 @@ func TestAddrReachabilityTracker(t *testing.T) { F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { for _, req := range reqs { if req.Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil } else if req.Addr.Equal(pub2) { - return autonatv2.Result{Addr: pub2, Status: pb.DialStatus_E_DIAL_ERROR}, nil + return autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil } } return autonatv2.Result{}, autonatv2.ErrDialRefused @@ -347,10 +347,10 @@ func TestAddrReachabilityTracker(t *testing.T) { default: } if !allow.Load() { - return autonatv2.Result{}, autonatv2.ErrNoValidPeers + return autonatv2.Result{}, autonatv2.ErrNoPeers } if reqs[0].Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil } return autonatv2.Result{}, autonatv2.ErrDialRefused }, @@ -418,7 +418,7 @@ func TestAddrReachabilityTracker(t *testing.T) { case called <- struct{}{}: notify <- struct{}{} } - return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil }, } @@ -453,7 +453,7 @@ func TestRunProbes(t *testing.T) { t.Run("backoff on ErrNoValidPeers", func(t *testing.T) { mockClient := mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { - return autonatv2.Result{}, autonatv2.ErrNoValidPeers + return autonatv2.Result{}, autonatv2.ErrNoPeers }, } @@ -530,7 +530,7 @@ func TestRunProbes(t *testing.T) { F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { for _, req := range reqs { if req.Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: 
pub1, DialStatus: pb.DialStatus_OK}, nil } } return autonatv2.Result{}, autonatv2.ErrDialRefused @@ -554,10 +554,10 @@ func TestRunProbes(t *testing.T) { F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { for _, req := range reqs { if req.Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, Status: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil } if req.Addr.Equal(pub2) { - return autonatv2.Result{Addr: pub2, Status: pb.DialStatus_E_DIAL_ERROR}, nil + return autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil } } return autonatv2.Result{}, autonatv2.ErrDialRefused @@ -592,6 +592,6 @@ func BenchmarkAddrTracker(b *testing.B) { pp = p } t.MarkProbeInProgress(pp) - t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Status: pb.DialStatus_OK}, nil) + t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, DialStatus: pb.DialStatus_OK}, nil) } } diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index 717bb65ef4..045ff577b5 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "iter" + "math/rand/v2" "slices" "sync" "time" @@ -13,7 +15,6 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) @@ -38,8 +39,10 @@ const ( ) var ( - ErrNoValidPeers = errors.New("no valid peers for autonat v2") - ErrDialRefused = errors.New("dial refused") + // ErrNoPeers is returned when the client knows no autonatv2 servers. + ErrNoPeers = errors.New("no peers for autonat v2") + // ErrPrivateAddrs is returned when the request has private IP addresses. 
+ ErrPrivateAddrs = errors.New("private addresses cannot be verified with autonatv2") log = logging.Logger("autonatv2") ) @@ -56,10 +59,12 @@ type Request struct { type Result struct { // Addr is the dialed address Addr ma.Multiaddr - // Reachability of the dialed address + // Idx is the index of the address that was dialed + Idx int + // Reachability is the reachability for `Addr` Reachability network.Reachability - // Status is the outcome of the dialback - Status pb.DialStatus + // AllAddrsRefused is true when the server refused to dial all the addresses in the request. + AllAddrsRefused bool } // AutoNAT implements the AutoNAT v2 client and server. @@ -77,7 +82,7 @@ type AutoNAT struct { cli *client mx sync.Mutex - peers map[peer.ID]struct{} + peers *peersMap throttlePeer map[peer.ID]time.Time throttlePeerDuration time.Duration // allowPrivateAddrs enables using private and localhost addresses for reachability checks. @@ -104,7 +109,7 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, srv: newServer(host, dialerHost, s), cli: newClient(host), allowPrivateAddrs: s.allowPrivateAddrs, - peers: make(map[peer.ID]struct{}), + peers: newPeersMap(), throttlePeer: make(map[peer.ID]time.Time), throttlePeerDuration: s.throttlePeerDuration, } @@ -173,14 +178,14 @@ func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, if !an.allowPrivateAddrs { for _, r := range reqs { if !manet.IsPublicAddr(r.Addr) { - return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr) + return Result{}, fmt.Errorf("%w: %s", ErrPrivateAddrs, r.Addr) } } } - now := time.Now() an.mx.Lock() + now := time.Now() var p peer.ID - for pr := range an.peers { + for pr := range an.peers.Shuffled() { if t := an.throttlePeer[pr]; t.After(now) { continue } @@ -190,7 +195,7 @@ func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, } an.mx.Unlock() if p == "" { - return Result{}, ErrNoValidPeers + 
 		return Result{}, ErrNoPeers
 	}
 	res, err := an.cli.GetReachability(ctx, p, reqs)
 	if err != nil {
@@ -210,8 +215,58 @@ func (an *AutoNAT) updatePeer(p peer.ID) {
 	protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol)
 	connectedness := an.host.Network().Connectedness(p)
 	if err == nil && connectedness == network.Connected && slices.Contains(protos, DialProtocol) {
-		an.peers[p] = struct{}{}
+		an.peers.Put(p)
 	} else {
-		delete(an.peers, p)
+		an.peers.Delete(p)
+	}
+}
+
+// peersMap provides random access to a set of peers. This is useful when the map iteration order is
+// not sufficiently random.
+type peersMap struct {
+	peerIdx map[peer.ID]int
+	peers   []peer.ID
+}
+
+func newPeersMap() *peersMap {
+	return &peersMap{
+		peerIdx: make(map[peer.ID]int),
+		peers:   make([]peer.ID, 0),
+	}
+}
+
+// Shuffled iterates over the map in random order
+func (p *peersMap) Shuffled() iter.Seq[peer.ID] {
+	n := len(p.peers)
+	start := 0
+	if n > 0 {
+		start = rand.IntN(len(p.peers))
+	}
+	return func(yield func(peer.ID) bool) {
+		for i := range n {
+			if !yield(p.peers[(i+start)%n]) {
+				return
+			}
+		}
+	}
+}
+
+func (p *peersMap) Put(id peer.ID) {
+	if _, ok := p.peerIdx[id]; ok {
+		return
+	}
+	p.peers = append(p.peers, id)
+	p.peerIdx[id] = len(p.peers) - 1
+}
+
+func (p *peersMap) Delete(id peer.ID) {
+	idx, ok := p.peerIdx[id]
+	if !ok {
+		return
 	}
+	p.peers[idx] = p.peers[len(p.peers)-1]
+	p.peerIdx[p.peers[idx]] = idx
+	p.peers[len(p.peers)-1] = ""
+	p.peers = p.peers[:len(p.peers)-1]
+	delete(p.peerIdx, id)
 }
diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go
index c2ecaba62e..791c0f9f83 100644
--- a/p2p/protocol/autonatv2/autonat_test.go
+++ b/p2p/protocol/autonatv2/autonat_test.go
@@ -75,7 +75,7 @@ func waitForPeer(t testing.TB, a *AutoNAT) {
 	require.Eventually(t, func() bool {
 		a.mx.Lock()
 		defer a.mx.Unlock()
-		return len(a.peers) != 0
+		return len(a.peers.peers) != 0
 	}, 5*time.Second, 100*time.Millisecond)
 }
 
@@ -155,
+155,6 @@ func TestClientServerError(t *testing.T) { }, errorStr: "invalid msg type", }, - { - handler: func(s network.Stream) { - w := pbio.NewDelimitedWriter(s) - assert.NoError(t, w.WriteMsg( - &pb.Message{Msg: &pb.Message_DialResponse{ - DialResponse: &pb.DialResponse{ - Status: pb.DialResponse_E_DIAL_REFUSED, - }, - }}, - )) - }, - errorStr: ErrDialRefused.Error(), - }, } for i, tc := range tests { @@ -508,7 +495,6 @@ func TestClientDialBacks(t *testing.T) { } else { require.NoError(t, err) require.Equal(t, res.Reachability, network.ReachabilityPublic) - require.Equal(t, res.Status, pb.DialStatus_OK) } }) } @@ -527,28 +513,28 @@ func TestEventSubscription(t *testing.T) { require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers) == 1 + return len(an.peers.peers) == 1 }, 5*time.Second, 100*time.Millisecond) idAndConnect(t, an.host, c) require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers) == 2 + return len(an.peers.peers) == 2 }, 5*time.Second, 100*time.Millisecond) an.host.Network().ClosePeer(b.ID()) require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers) == 1 + return len(an.peers.peers) == 1 }, 5*time.Second, 100*time.Millisecond) an.host.Network().ClosePeer(c.ID()) require.Eventually(t, func() bool { an.mx.Lock() defer an.mx.Unlock() - return len(an.peers) == 0 + return len(an.peers.peers) == 0 }, 5*time.Second, 100*time.Millisecond) } diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index d71e68d73d..f15cb9ecbe 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -136,7 +136,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request // E_DIAL_REFUSED has implication for deciding future address verificiation priorities // wrap a distinct error for convenient errors.Is usage if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED { - return Result{}, 
fmt.Errorf("dial request failed: %w", ErrDialRefused) + return Result{AllAddrsRefused: true}, nil } return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(), pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())]) @@ -162,7 +162,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request return ac.newResult(resp, reqs, dialBackAddr) } -func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error { +func (*client) validateDialDataRequest(reqs []Request, msg *pb.Message) error { idx := int(msg.GetDialDataRequest().AddrIdx) if idx >= len(reqs) { // invalid address index return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs)) @@ -178,9 +178,13 @@ func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) { idx := int(resp.AddrIdx) + if idx >= len(reqs) { + // This should have been validated by this point, but checking this is cheap. + return Result{}, fmt.Errorf("addrs index(%d) greater than len(reqs)(%d)", idx, len(reqs)) + } addr := reqs[idx].Addr - var rch network.Reachability + rch := network.ReachabilityUnknown switch resp.DialStatus { case pb.DialStatus_OK: if !ac.areAddrsConsistent(dialBackAddr, addr) { @@ -190,17 +194,16 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr) } rch = network.ReachabilityPublic - case pb.DialStatus_E_DIAL_ERROR: - rch = network.ReachabilityPrivate case pb.DialStatus_E_DIAL_BACK_ERROR: - if ac.areAddrsConsistent(dialBackAddr, addr) { - // We received the dial back but the server claims the dial back errored. - // As long as we received the correct nonce in dial back it is safe to assume - // that we are public. 
- rch = network.ReachabilityPublic - } else { - rch = network.ReachabilityUnknown + if !ac.areAddrsConsistent(dialBackAddr, addr) { + return Result{}, fmt.Errorf("dial-back stream error: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr) } + // We received the dial back but the server claims the dial back errored. + // As long as we received the correct nonce in dial back it is safe to assume + // that we are public. + rch = network.ReachabilityPublic + case pb.DialStatus_E_DIAL_ERROR: + rch = network.ReachabilityPrivate default: // Unexpected response code. Discard the response and fail. log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus) @@ -209,8 +212,8 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr return Result{ Addr: addr, + Idx: idx, Reachability: rch, - Status: resp.DialStatus, }, nil } @@ -299,6 +302,7 @@ func (ac *client) handleDialBack(s network.Stream) { } w := pbio.NewDelimitedWriter(s) res := pb.DialBackResponse{} + // TODO: Check what happens on sending empty if err := w.WriteMsg(&res); err != nil { log.Debugf("failed to write dialback response: %s", err) s.Reset() @@ -306,7 +310,8 @@ func (ac *client) handleDialBack(s network.Stream) { } func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool { - if connLocalAddr == nil || dialedAddr == nil { + // TODO: Check this n times + if len(connLocalAddr) == 0 || len(dialedAddr) == 0 { return false } connLocalAddr = ac.normalizeMultiaddr(connLocalAddr) diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 6664f1c397..73c89a55fc 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "math" + "strings" "sync" "sync/atomic" "testing" @@ -46,8 +47,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) res, err := c.GetReachability(context.Background(), 
newTestRequests(c.host.Addrs(), true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("black holed addr", func(t *testing.T) { @@ -64,8 +65,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { Addr: ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1"), SendDialData: true, }}) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("private addrs", func(t *testing.T) { @@ -76,8 +77,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("relay addrs", func(t *testing.T) { @@ -89,8 +90,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { res, err := c.GetReachability(context.Background(), newTestRequests( []ma.Multiaddr{ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p/%s/p2p-circuit/p2p/%s", c.host.ID(), c.srv.dialerHost.ID()))}, true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("no addr", func(t *testing.T) { @@ -113,8 +114,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) res, err := c.GetReachability(context.Background(), newTestRequests(addrs, true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("msg too large", func(t *testing.T) { @@ -178,8 +179,8 @@ func TestServerDataRequest(t *testing.T) { require.Equal(t, Result{ Addr: quicAddr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: 
pb.DialStatus_OK, }, res) // Small messages should be rejected for dial data @@ -191,14 +192,11 @@ func TestServerDataRequest(t *testing.T) { func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { const concurrentRequests = 5 - // server will skip all tcp addresses - dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) - - doneChan := make(chan struct{}) - an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( + stallChan := make(chan struct{}) + an := newAutoNAT(t, nil, allowPrivateAddrs, withDataRequestPolicy( // stall all allowed requests func(s network.Stream, dialAddr ma.Multiaddr) bool { - <-doneChan + <-stallChan return true }), WithServerRateLimit(10, 10, 10, concurrentRequests), @@ -207,16 +205,18 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { defer an.Close() defer an.host.Close() - c := newAutoNAT(t, nil, allowPrivateAddrs) + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + c := newAutoNAT(t, dialer, allowPrivateAddrs) defer c.Close() defer c.host.Close() idAndWait(t, c, an) errChan := make(chan error) - const N = 10 - // num concurrentRequests will stall and N will fail - for i := 0; i < concurrentRequests+N; i++ { + const n = 10 + // num concurrentRequests will stall and n will fail + for i := 0; i < concurrentRequests+n; i++ { go func() { _, err := c.GetReachability(context.Background(), []Request{{Addr: c.host.Addrs()[0], SendDialData: false}}) errChan <- err @@ -224,17 +224,20 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { } // check N failures - for i := 0; i < N; i++ { + for i := 0; i < n; i++ { select { case err := <-errChan: require.Error(t, err) + if !strings.Contains(err.Error(), "stream reset") && !strings.Contains(err.Error(), "E_REQUEST_REJECTED") { + t.Fatalf("invalid error: %s expected: stream reset or E_REQUEST_REJECTED", err) + } case <-time.After(10 * time.Second): - t.Fatalf("expected %d errors: got: 
%d", N, i) + t.Fatalf("expected %d errors: got: %d", n, i) } } + close(stallChan) // complete stalled requests // check concurrentRequests failures, as we won't send dial data - close(doneChan) for i := 0; i < concurrentRequests; i++ { select { case err := <-errChan: @@ -290,8 +293,8 @@ func TestServerDataRequestJitter(t *testing.T) { require.Equal(t, Result{ Addr: quicAddr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) if took > 500*time.Millisecond { return @@ -320,8 +323,8 @@ func TestServerDial(t *testing.T) { require.NoError(t, err) require.Equal(t, Result{ Addr: unreachableAddr, + Idx: 0, Reachability: network.ReachabilityPrivate, - Status: pb.DialStatus_E_DIAL_ERROR, }, res) }) @@ -330,16 +333,16 @@ func TestServerDial(t *testing.T) { require.NoError(t, err) require.Equal(t, Result{ Addr: hostAddrs[0], + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) for _, addr := range c.host.Addrs() { res, err := c.GetReachability(context.Background(), newTestRequests([]ma.Multiaddr{addr}, false)) require.NoError(t, err) require.Equal(t, Result{ Addr: addr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) } }) @@ -347,12 +350,8 @@ func TestServerDial(t *testing.T) { t.Run("dialback error", func(t *testing.T) { c.host.RemoveStreamHandler(DialBackProtocol) res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) - require.NoError(t, err) - require.Equal(t, Result{ - Addr: hostAddrs[0], - Reachability: network.ReachabilityUnknown, - Status: pb.DialStatus_E_DIAL_BACK_ERROR, - }, res) + require.ErrorContains(t, err, "dial-back stream error") + require.Equal(t, Result{}, res) }) } @@ -396,7 +395,6 @@ func TestRateLimiter(t *testing.T) { cl.AdvanceBy(10 * time.Second) require.True(t, r.Accept("peer3")) - } func TestRateLimiterConcurrentRequests(t *testing.T) { From 3b184e59b853754d4be836c8edebb01108ef06b0 Mon Sep 17 00:00:00 
2001 From: sukun Date: Tue, 15 Apr 2025 17:54:19 +0530 Subject: [PATCH 04/15] limit probes per address per minute --- p2p/host/basic/addrs_reachability_tracker.go | 357 ++++++++++-------- .../basic/addrs_reachability_tracker_test.go | 22 +- p2p/protocol/autonatv2/autonat.go | 22 +- 3 files changed, 225 insertions(+), 176 deletions(-) diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index 68cacfe8b1..0b6022a689 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -41,7 +41,7 @@ type addrsReachabilityTracker struct { reachabilityUpdateCh chan struct{} maxConcurrency int newAddrsProbeDelay time.Duration - addrTracker *addrsProbeTracker + addrTracker *probeManager newAddrs chan []ma.Multiaddr clock clock.Clock @@ -64,7 +64,7 @@ func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh ch cancel: cancel, cli: client, reachabilityUpdateCh: reachabilityUpdateCh, - addrTracker: newAddrsTracker(cl.Now, maxRecentProbeResultWindow), + addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow), newAddrsProbeDelay: 1 * time.Second, maxConcurrency: defaultMaxConcurrency, newAddrs: make(chan []ma.Multiaddr, 1), @@ -252,7 +252,7 @@ const maxConsecutiveErrors = 20 // - context is completed // - there are too many consecutive failures from the client // - the client has no valid peers to probe -func runProbes(ctx context.Context, concurrency int, addrsTracker *addrsProbeTracker, client autonatv2Client) bool { +func runProbes(ctx context.Context, concurrency int, addrsTracker *probeManager, client autonatv2Client) bool { client = &errCountingClient{autonatv2Client: client, MaxConsecutiveErrors: maxConsecutiveErrors} resultsCh := make(chan probeResponse, 2*concurrency) // enough buffer to allow all worker goroutines to exit quickly @@ -314,16 +314,45 @@ func isErrorPersistent(err error) bool { errors.Is(err, errTooManyConsecutiveFailures) } -// 
addrsProbeTracker tracks reachability for a set of addresses. This struct decides the priority order of +const ( + // maxProbeResultTTL is the maximum time to keep probe results for an address + maxProbeResultTTL = 3 * time.Hour + // maxProbeInterval is the maximum interval between probes for an address + maxProbeInterval = 1 * time.Hour + // addrRefusedProbeInterval is the interval to probe addresses that have been refused + // these are generally addresses with newer transports for which we don't have many peers + // capable of dialing the transport + addrRefusedProbeInterval = 10 * time.Minute + // maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which + // we wait for `addrRefusedProbeInterval` before probing again + maxConsecutiveRefusals = 5 + // confidence is the absolute difference between the number of successes and failures for an address + // targetConfidence is the confidence threshold for an address after which we wait for `maxProbeInterval` + // before probing again + targetConfidence = 3 + // minConfidence is the confidence threshold for an address to be considered reachable or unreachable + // confidence is the absolute difference between the number of successes and failures for an address + minConfidence = 2 + // maxRecentProbeResultWindow is the maximum number of recent probe results to consider for a single address + // + // +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure + // and then a success(...S S S S F S). The confidence in the targetConfidence window will be equal to + // targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval. + maxRecentProbeResultWindow = targetConfidence + 2 +) + +// probeManager tracks reachability for a set of addresses. This struct decides the priority order of // addresses for testing reachability. 
// // Use the `runProbes` function to execute the probes with an autonatv2 client. // // Probes returned by `GetProbe` should be marked as in progress using `MarkProbeInProgress` // before being executed. -type addrsProbeTracker struct { - now func() time.Time - recentProbeResultWindow int +type probeManager struct { + now func() time.Time + recentProbeResultWindow int + ProbeInterval time.Duration + MaxProbesPerAddrsPerInterval int mx sync.Mutex inProgressProbes map[string]int // addr -> count @@ -332,8 +361,8 @@ type addrsProbeTracker struct { addrs []ma.Multiaddr } -func newAddrsTracker(now func() time.Time, recentProbeResultWindow int) *addrsProbeTracker { - return &addrsProbeTracker{ +func newProbeManager(now func() time.Time, recentProbeResultWindow int) *probeManager { + return &probeManager{ statuses: make(map[string]*addrStatus), inProgressProbes: make(map[string]int), now: now, @@ -342,13 +371,12 @@ func newAddrsTracker(now func() time.Time, recentProbeResultWindow int) *addrsPr } // AppendConfirmedAddrs appends the current confirmed reachable and unreachable addresses. 
-func (t *addrsProbeTracker) AppendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { - t.mx.Lock() - defer t.mx.Unlock() +func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + m.mx.Lock() + defer m.mx.Unlock() - t.gc() - for _, as := range t.statuses { - switch as.Reachability() { + for _, as := range m.statuses { + switch as.outcomes.Reachability() { case network.ReachabilityPublic: reachable = append(reachable, as.Addr) case network.ReachabilityPrivate: @@ -359,15 +387,15 @@ func (t *addrsProbeTracker) AppendConfirmedAddrs(reachable, unreachable []ma.Mul } // UpdateAddrs updates the tracked addrs -func (t *addrsProbeTracker) UpdateAddrs(addrs []ma.Multiaddr) { - t.mx.Lock() - defer t.mx.Unlock() +func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) { + m.mx.Lock() + defer m.mx.Unlock() for _, addr := range addrs { - if _, ok := t.statuses[string(addr.Bytes())]; !ok { - t.statuses[string(addr.Bytes())] = &addrStatus{Addr: addr} + if _, ok := m.statuses[string(addr.Bytes())]; !ok { + m.statuses[string(addr.Bytes())] = &addrStatus{Addr: addr} } } - for k, s := range t.statuses { + for k, s := range m.statuses { found := false for _, a := range addrs { if a.Equal(s.Addr) { @@ -376,28 +404,28 @@ func (t *addrsProbeTracker) UpdateAddrs(addrs []ma.Multiaddr) { } } if !found { - delete(t.statuses, k) + delete(m.statuses, k) } } - t.addrs = addrs + m.addrs = addrs } // GetProbe returns the next probe. Returns empty slice in case there are no more probes. // Probes that are run against an autonatv2 client should be marked in progress with // `MarkProbeInProgress` before running. 
-func (t *addrsProbeTracker) GetProbe() []autonatv2.Request { - t.mx.Lock() - defer t.mx.Unlock() +func (m *probeManager) GetProbe() []autonatv2.Request { + m.mx.Lock() + defer m.mx.Unlock() reqs := make([]autonatv2.Request, 0, maxAddrsPerRequest) - now := t.now() - for _, a := range t.addrs { + now := m.now() + for _, a := range m.addrs { akey := string(a.Bytes()) - pc := t.statuses[akey].ProbeCount(now) + pc := m.probeCount(m.statuses[akey], now) if pc == 0 { continue } - if len(reqs) == 0 && t.inProgressProbes[akey] >= pc { + if len(reqs) == 0 && m.inProgressProbes[akey] >= pc { continue } reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true}) @@ -410,41 +438,41 @@ func (t *addrsProbeTracker) GetProbe() []autonatv2.Request { // MarkProbeInProgress should be called when a probe is started. // All in progress probes *MUST* be completed with `CompleteProbe` -func (t *addrsProbeTracker) MarkProbeInProgress(reqs []autonatv2.Request) { +func (m *probeManager) MarkProbeInProgress(reqs []autonatv2.Request) { if len(reqs) == 0 { return } - t.mx.Lock() - defer t.mx.Unlock() - t.inProgressProbes[string(reqs[0].Addr.Bytes())]++ - t.inProgressProbesTotal++ + m.mx.Lock() + defer m.mx.Unlock() + m.inProgressProbes[string(reqs[0].Addr.Bytes())]++ + m.inProgressProbesTotal++ } // InProgressProbes returns the number of probes that are currently in progress. -func (t *addrsProbeTracker) InProgressProbes() int { - t.mx.Lock() - defer t.mx.Unlock() - return t.inProgressProbesTotal +func (m *probeManager) InProgressProbes() int { + m.mx.Lock() + defer m.mx.Unlock() + return m.inProgressProbesTotal } // CompleteProbe should be called when a probe completes. 
-func (t *addrsProbeTracker) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Result, err error) { - now := t.now() +func (m *probeManager) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Result, err error) { + now := m.now() if len(reqs) == 0 { // should never happen return } - t.mx.Lock() - defer t.mx.Unlock() + m.mx.Lock() + defer m.mx.Unlock() // decrement in-progress count for the first address primaryAddrKey := string(reqs[0].Addr.Bytes()) - t.inProgressProbes[primaryAddrKey]-- - t.inProgressProbesTotal-- - if t.inProgressProbes[primaryAddrKey] <= 0 { - delete(t.inProgressProbes, primaryAddrKey) + m.inProgressProbes[primaryAddrKey]-- + m.inProgressProbesTotal-- + if m.inProgressProbes[primaryAddrKey] <= 0 { + delete(m.inProgressProbes, primaryAddrKey) } // nothing to do if the request errored. @@ -452,91 +480,81 @@ func (t *addrsProbeTracker) CompleteProbe(reqs []autonatv2.Request, res autonatv return } + expireBefore := now.Add(-maxProbeInterval) // request failed if res.AllAddrsRefused { - for _, req := range reqs { - if status, ok := t.statuses[string(req.Addr.Bytes())]; ok { - status.AddRefusal(now) - } + if s, ok := m.statuses[primaryAddrKey]; ok { + s.lastRefusalTime = now + s.consecutiveRefusals++ } return } - // mark addresses that were skipped as refused - for _, req := range reqs[:res.Idx] { - if status, ok := t.statuses[string(req.Addr.Bytes())]; ok { - status.AddRefusal(now) - } + // mark only the primary status as refused + if s, ok := m.statuses[primaryAddrKey]; ok { + s.lastRefusalTime = now + s.consecutiveRefusals++ } // record the result for the probed address - if status, ok := t.statuses[string(res.Addr.Bytes())]; ok { - switch res.Reachability { - case network.ReachabilityPublic: - status.AddResult(now, true) - case network.ReachabilityPrivate: - status.AddResult(now, false) - default: - log.Debug("unexpected dial status", res.Addr) + if s, ok := m.statuses[string(res.Addr.Bytes())]; ok { + 
s.outcomes.AddDialOutcomeAndExpire(now, res.Reachability, m.recentProbeResultWindow, expireBefore) + s.probeTimes = append(s.probeTimes, now) + } +} + +func (m *probeManager) probeCount(s *addrStatus, now time.Time) int { + if s.consecutiveRefusals >= maxConsecutiveRefusals { + if now.Sub(s.lastRefusalTime) < addrRefusedProbeInterval { + return 0 } - status.Trim(t.recentProbeResultWindow) + // reset this + s.lastRefusalTime = time.Time{} + s.consecutiveRefusals = 0 } + + // Don't probe if we have probed too many times recently + if m.recentProbesCount(s, now) >= m.MaxProbesPerAddrsPerInterval { + return 0 + } + + return s.outcomes.ProbeCount(now) } -func (t *addrsProbeTracker) gc() { - expireBefore := t.now().Add(-maxProbeResultTTL) - for _, s := range t.statuses { - s.ExpireBefore(expireBefore) +func (m *probeManager) recentProbesCount(s *addrStatus, now time.Time) int { + cnt := 0 + for _, t := range slices.Backward(s.probeTimes) { + if now.Sub(t) > m.ProbeInterval { + break + } + cnt++ } + return cnt } -type probeResult struct { - Time time.Time +type dialOutcome struct { Success bool + At time.Time } -const ( - // maxProbeResultTTL is the maximum time to keep probe results for an address - maxProbeResultTTL = 3 * time.Hour - // maxProbeInterval is the maximum interval between probes for an address - maxProbeInterval = 1 * time.Hour - // addrRefusedProbeInterval is the interval to probe addresses that have been refused - // these are generally addresses with newer transports for which we don't have many peers - // capable of dialing the transport - addrRefusedProbeInterval = 10 * time.Minute - // maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which - // we wait for `addrRefusedProbeInterval` before probing again - maxConsecutiveRefusals = 5 - // confidence is the absolute difference between the number of successes and failures for an address - // targetConfidence is the confidence threshold for an address after which we 
wait for `maxProbeInterval` - // before probing again - targetConfidence = 3 - // minConfidence is the confidence threshold for an address to be considered reachable or unreachable - // confidence is the absolute difference between the number of successes and failures for an address - minConfidence = 2 - // maxRecentProbeResultWindow is the maximum number of recent probe results to consider for a single address - // - // +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure - // and then a success(...S S S S F S). The confidence in the targetConfidence window will be equal to - // targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval. - maxRecentProbeResultWindow = targetConfidence + 2 -) - type addrStatus struct { Addr ma.Multiaddr - results []probeResult - consecutiveRefusals struct { - Count int - Last time.Time - } + lastRefusalTime time.Time + consecutiveRefusals int + probeTimes []time.Time + outcomes *addrOutcomes +} + +type addrOutcomes struct { + outcomes []dialOutcome } -func (s *addrStatus) Reachability() network.Reachability { - successes, failures := s.resultCounts() - return s.reachability(successes, failures) +func (o *addrOutcomes) Reachability() network.Reachability { + successes, failures := o.outcomeCounts() + return o.computeReachability(successes, failures) } -func (*addrStatus) reachability(success, failures int) network.Reachability { +func (*addrOutcomes) computeReachability(success, failures int) network.Reachability { if success-failures >= minConfidence { return network.ReachabilityPublic } @@ -546,48 +564,87 @@ func (*addrStatus) reachability(success, failures int) network.Reachability { return network.ReachabilityUnknown } -func (s *addrStatus) ProbeCount(now time.Time) int { - // if we have had too many consecutive refusals, probe after a small wait. 
- if s.consecutiveRefusals.Count >= maxConsecutiveRefusals { - if s.consecutiveRefusals.Last.Add(addrRefusedProbeInterval).Before(now) { - return 1 +func (o *addrOutcomes) numProbesInInterval(now time.Time, probeInterval time.Duration) int { + cnt := 0 + for _, v := range slices.Backward(o.outcomes) { + if now.Sub(v.At) > probeInterval { + break } - return 0 + cnt++ } + return cnt +} - successes, failures := s.resultCounts() - cnt := 0 - if successes >= failures { - cnt = targetConfidence - (successes - failures) +func (o *addrOutcomes) ProbeCount(now time.Time) int { + successes, failures := o.outcomeCounts() + confidence := successes - failures + if confidence < 0 { + confidence = -confidence } - if failures >= successes { - cnt = targetConfidence - (failures - successes) + cnt := targetConfidence - confidence + if cnt > 0 { + return cnt } - if cnt <= 0 { - if len(s.results) == 0 { - return 0 + // we have enough confirmations, see if we should still retest + + // There are no confirmations. This should never happen. In case there are no confirmations, + // the confidence logic above should require a few probes. + if len(o.outcomes) == 0 { + return 0 + } + lastOutcome := o.outcomes[len(o.outcomes)-1] + // If the last probe result is old, we need to retest + if now.Sub(lastOutcome.At) > maxProbeInterval { + return 1 + } + // if the last probe result was different from reachability, probe again. + switch o.computeReachability(successes, failures) { + case network.ReachabilityPublic: + if !lastOutcome.Success { + return 1 } - if s.results[len(s.results)-1].Time.Add(maxProbeInterval).Before(now) { + case network.ReachabilityPrivate: + if lastOutcome.Success { return 1 } - // Last probe result was different from reachability. Probe again. 
- switch s.reachability(successes, failures) { - case network.ReachabilityPublic: - if !s.results[len(s.results)-1].Success { - return 1 - } - case network.ReachabilityPrivate: - if s.results[len(s.results)-1].Success { - return 1 - } + default: + // this should never happen. no reachability => confidence is low + return 1 + } + return 0 +} + +func (o *addrOutcomes) AddDialOutcomeAndExpire(at time.Time, rch network.Reachability, windowSize int, expireBefore time.Time) { + success := false + switch rch { + case network.ReachabilityPublic: + success = true + case network.ReachabilityPrivate: + success = false + default: + return // don't store the outcome if reachability is unknown + } + o.outcomes = append(o.outcomes, dialOutcome{At: at, Success: success}) + if len(o.outcomes) > windowSize { + o.outcomes = slices.Delete(o.outcomes, 0, len(o.outcomes)-windowSize) + } + o.removeBefore(expireBefore) +} + +// removeBefore removes outcomes before t +func (o *addrOutcomes) removeBefore(t time.Time) { + st := 0 + var ot dialOutcome + for st, ot = range o.outcomes { + if ot.At.After(t) { + break } - return 0 } - return cnt + o.outcomes = slices.Delete(o.outcomes, 0, st) } -func (s *addrStatus) resultCounts() (successes, failures int) { - for _, r := range s.results { +func (o *addrOutcomes) outcomeCounts() (successes, failures int) { + for _, r := range o.outcomes { if r.Success { successes++ } else { @@ -596,29 +653,3 @@ func (s *addrStatus) resultCounts() (successes, failures int) { } return successes, failures } - -func (s *addrStatus) ExpireBefore(expiry time.Time) { - s.results = slices.DeleteFunc(s.results, func(pr probeResult) bool { - return pr.Time.Before(expiry) - }) -} - -func (s *addrStatus) AddResult(at time.Time, success bool) { - s.results = append(s.results, probeResult{ - Success: success, - Time: at, - }) - s.consecutiveRefusals.Count = 0 - s.consecutiveRefusals.Last = time.Time{} -} - -func (s *addrStatus) Trim(n int) { - if len(s.results) > n { - 
s.results = slices.Delete(s.results, 0, len(s.results)-n) - } -} - -func (s *addrStatus) AddRefusal(at time.Time) { - s.consecutiveRefusals.Count++ - s.consecutiveRefusals.Last = at -} diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index c80909344d..dcaf033b62 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -25,7 +25,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { cl := clock.NewMock() t.Run("inprogress probes", func(t *testing.T) { - tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) reqs1 := tr.GetProbe() @@ -50,7 +50,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { }) t.Run("probe refusals", func(t *testing.T) { - tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) var probes [][]autonatv2.Request for i := 0; i < 3; i++ { @@ -86,7 +86,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { }) t.Run("probe successes", func(t *testing.T) { - tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) var probes [][]autonatv2.Request for i := 0; i < 3; i++ { @@ -125,7 +125,7 @@ func TestAddrTrackerGetProbe(t *testing.T) { }) t.Run("reachabilityUpdate", func(t *testing.T) { - tr := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) var probes [][]autonatv2.Request for i := 0; i < 3; i++ { @@ -280,7 +280,7 @@ func TestAddrReachabilityTracker(t *testing.T) { reachabilityUpdateCh: make(chan struct{}, 1), maxConcurrency: 3, newAddrsProbeDelay: 0 * time.Second, - addrTracker: newAddrsTracker(cl.Now, 
maxRecentProbeResultWindow), + addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow), clock: cl, } err := tr.Start() @@ -457,7 +457,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) require.True(t, result) @@ -472,7 +472,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) @@ -490,7 +490,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) var wg sync.WaitGroup wg.Add(1) @@ -523,7 +523,7 @@ func TestRunProbes(t *testing.T) { t.Run("handles refusals", func(t *testing.T) { pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") - addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ @@ -547,7 +547,7 @@ func TestRunProbes(t *testing.T) { }) t.Run("handles completions", func(t *testing.T) { - addrTracker := newAddrsTracker(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ @@ -576,7 +576,7 @@ func TestRunProbes(t *testing.T) { func BenchmarkAddrTracker(b *testing.B) { cl := clock.NewMock() - t := newAddrsTracker(cl.Now, maxRecentProbeResultWindow) + t := newProbeManager(cl.Now, 
maxRecentProbeResultWindow) var addrs []ma.Multiaddr for i := 0; i < 20; i++ { diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index 045ff577b5..66c124890d 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -132,6 +132,8 @@ func (an *AutoNAT) background(sub event.Subscription) { an.updatePeer(evt.Peer) case event.EvtPeerIdentificationCompleted: an.updatePeer(evt.Peer) + default: + log.Errorf("unexpected event: %T", e) } case <-ticker.C: now := time.Now() @@ -175,12 +177,21 @@ func (an *AutoNAT) Close() { // GetReachability makes a single dial request for checking reachability for requested addresses func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) { + var filteredReqs []Request if !an.allowPrivateAddrs { + filteredReqs = make([]Request, 0, len(reqs)) for _, r := range reqs { - if !manet.IsPublicAddr(r.Addr) { - return Result{}, fmt.Errorf("%w: %s", ErrPrivateAddrs, r.Addr) + if manet.IsPublicAddr(r.Addr) { + filteredReqs = append(filteredReqs, r) + } else { + log.Errorf("private address in reachability check: %s", r.Addr) } } + if len(filteredReqs) == 0 { + return Result{}, ErrPrivateAddrs + } + } else { + filteredReqs = reqs } an.mx.Lock() now := time.Now() @@ -202,6 +213,13 @@ func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, log.Debugf("reachability check with %s failed, err: %s", p, err) return res, fmt.Errorf("reachability check with %s failed: %w", p, err) } + // restore the correct index in case we'd filtered private addresses + for i, r := range reqs { + if r.Addr.Equal(res.Addr) { + res.Idx = i + break + } + } log.Debugf("reachability check with %s successful", p) return res, nil } From 058baab3dfb22f5f98260fea2a47452bac3c5370 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 16 Apr 2025 17:29:06 +0530 Subject: [PATCH 05/15] fix event ordering --- p2p/host/basic/addrs_reachability_tracker.go | 80 ++++++++++++-------- 
1 file changed, 50 insertions(+), 30 deletions(-) diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index 0b6022a689..c37ac2a7d2 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -103,40 +103,37 @@ func (r *addrsReachabilityTracker) background() error { var task reachabilityTask var backoffInterval time.Duration var reachable, unreachable []ma.Multiaddr // used to avoid allocations + var prevReachable, prevUnreachable []ma.Multiaddr for { + var resetInterval time.Duration select { case <-timer.C: if task.RespCh == nil { task = r.refreshReachability() } - timer.Reset(defaultResetInterval) + resetInterval = defaultResetInterval case backoff := <-task.RespCh: task = reachabilityTask{} if backoff { backoffInterval = newBackoffInterval(backoffInterval) - } else { - backoffInterval = 0 } - reachable, unreachable = r.appendConfirmedAddrsAndNotify(reachable[:0], unreachable[:0]) - timer.Reset(backoffInterval) + reachable, unreachable = r.appendConfirmedAddrs(reachable[:0], unreachable[:0]) + resetInterval = 0 case addrs := <-r.newAddrs: if task.RespCh != nil { task.Cancel() - <-task.RespCh + backoff := <-task.RespCh task = reachabilityTask{} - // We must send the event here. If there are no new addrs in this event we may not probe - // again for a while delaying any reachability updates. - reachable, unreachable = r.appendConfirmedAddrsAndNotify(reachable[:0], unreachable[:0]) - } - addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { - return !manet.IsPublicAddr(a) - }) - if len(addrs) > maxTrackedAddrs { - log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs) - addrs = addrs[:maxTrackedAddrs] + if backoff { + backoffInterval = newBackoffInterval(backoffInterval) + } + // We must update reachable addrs here. 
+ // If there are no new addrs in this event we may not probe again for a while + // and suppress any reachability updates. + reachable, unreachable = r.appendConfirmedAddrs(reachable[:0], unreachable[:0]) } - r.addrTracker.UpdateAddrs(addrs) - timer.Reset(r.newAddrsProbeDelay) + resetInterval = r.newAddrsProbeDelay + r.updateTrackedAddrs(addrs) case <-r.ctx.Done(): if task.RespCh != nil { task.Cancel() @@ -145,28 +142,51 @@ func (r *addrsReachabilityTracker) background() error { } return } + + if areAddrsDifferent(prevReachable, r.reachableAddrs) || areAddrsDifferent(prevUnreachable, r.unreachableAddrs) { + reachable, unreachable = r.appendConfirmedAddrs(reachable[:0], unreachable[:0]) + } + prevReachable, prevUnreachable = r.reachableAddrs, r.unreachableAddrs + if backoffInterval > resetInterval { + resetInterval = backoffInterval + } + timer.Reset(resetInterval) } }() return nil } -func (r *addrsReachabilityTracker) appendConfirmedAddrsAndNotify(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { +func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + r.mx.Lock() + defer r.mx.Unlock() + return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs) +} + +func (r *addrsReachabilityTracker) appendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { reachable, unreachable = r.addrTracker.AppendConfirmedAddrs(reachable, unreachable) r.mx.Lock() r.reachableAddrs = append(r.reachableAddrs[:0], reachable...) r.unreachableAddrs = append(r.unreachableAddrs[:0], unreachable...) 
r.mx.Unlock() + return reachable, unreachable +} + +func (r *addrsReachabilityTracker) notify() { select { case r.reachabilityUpdateCh <- struct{}{}: default: } - return reachable, unreachable } -func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) { - r.mx.Lock() - defer r.mx.Unlock() - return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs) +func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) { + addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { + return !manet.IsPublicAddr(a) + }) + if len(addrs) > maxTrackedAddrs { + log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs) + addrs = addrs[:maxTrackedAddrs] + } + r.addrTracker.UpdateAddrs(addrs) } const ( @@ -222,7 +242,7 @@ func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv res, err := c.autonatv2Client.GetReachability(ctx, reqs) c.mx.Lock() defer c.mx.Unlock() - if err != nil { + if err != nil && !errors.Is(err, context.Canceled) { // ignore canceled errors, they're not errors from autonatv2 c.consecutiveErrors++ if c.consecutiveErrors > c.MaxConsecutiveErrors { err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err) @@ -375,12 +395,12 @@ func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiadd m.mx.Lock() defer m.mx.Unlock() - for _, as := range m.statuses { - switch as.outcomes.Reachability() { + for _, a := range m.addrs { + switch m.statuses[string(a.Bytes())].outcomes.Reachability() { case network.ReachabilityPublic: - reachable = append(reachable, as.Addr) + reachable = append(reachable, a) case network.ReachabilityPrivate: - unreachable = append(unreachable, as.Addr) + unreachable = append(unreachable, a) } } return reachable, unreachable @@ -542,7 +562,7 @@ type addrStatus struct { lastRefusalTime time.Time consecutiveRefusals int probeTimes []time.Time - outcomes *addrOutcomes + 
outcomes *addrOutcomes // TODO: no pointer? } type addrOutcomes struct { From 2dd86620da1509acf6749471e6a85fb5872390c9 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 16 Apr 2025 20:40:15 +0530 Subject: [PATCH 06/15] fix backoff calculation --- p2p/host/basic/addrs_reachability_tracker.go | 79 +++++++++++--------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index c37ac2a7d2..f50fa0e8a2 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -97,42 +97,51 @@ func (r *addrsReachabilityTracker) background() error { go func() { defer r.wg.Done() - timer := r.clock.Timer(time.Duration(math.MaxInt64)) - defer timer.Stop() + // probeTicker is used to trigger probes at regular intervals + probeTicker := r.clock.Ticker(defaultResetInterval) + defer probeTicker.Stop() + + // probeTimer is used to trigger probes at specific times + probeTimer := r.clock.Timer(time.Duration(math.MaxInt64)) + defer probeTimer.Stop() + nextProbeTime := time.Time{} var task reachabilityTask var backoffInterval time.Duration - var reachable, unreachable []ma.Multiaddr // used to avoid allocations - var prevReachable, prevUnreachable []ma.Multiaddr + var currReachable, currUnreachable, prevReachable, prevUnreachable []ma.Multiaddr for { - var resetInterval time.Duration select { - case <-timer.C: + case <-probeTicker.C: if task.RespCh == nil { task = r.refreshReachability() } - resetInterval = defaultResetInterval + nextProbeTime = time.Time{} + case <-probeTimer.C: + if task.RespCh == nil { + task = r.refreshReachability() + } + nextProbeTime = time.Time{} case backoff := <-task.RespCh: task = reachabilityTask{} + // On completion, start the next probe immediately, or wait for backoff + // In case there are no further probes, the reachability tracker will return an empty task, + // which hangs forever. 
Eventually, we'll refresh again when the ticker fires. if backoff { backoffInterval = newBackoffInterval(backoffInterval) + } else { + backoffInterval = -1 * time.Second // negative to trigger next probe immediately } - reachable, unreachable = r.appendConfirmedAddrs(reachable[:0], unreachable[:0]) - resetInterval = 0 + nextProbeTime = r.clock.Now().Add(backoffInterval) case addrs := <-r.newAddrs: - if task.RespCh != nil { + if task.RespCh != nil { // cancel running task. task.Cancel() - backoff := <-task.RespCh + <-task.RespCh // ignore backoff from cancelled task task = reachabilityTask{} - if backoff { - backoffInterval = newBackoffInterval(backoffInterval) - } - // We must update reachable addrs here. - // If there are no new addrs in this event we may not probe again for a while - // and suppress any reachability updates. - reachable, unreachable = r.appendConfirmedAddrs(reachable[:0], unreachable[:0]) } - resetInterval = r.newAddrsProbeDelay + newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay) + if nextProbeTime.Before(newAddrsNextTime) { + nextProbeTime = newAddrsNextTime + } r.updateTrackedAddrs(addrs) case <-r.ctx.Done(): if task.RespCh != nil { @@ -143,14 +152,15 @@ func (r *addrsReachabilityTracker) background() error { return } - if areAddrsDifferent(prevReachable, r.reachableAddrs) || areAddrsDifferent(prevUnreachable, r.unreachableAddrs) { - reachable, unreachable = r.appendConfirmedAddrs(reachable[:0], unreachable[:0]) + currReachable, currUnreachable = r.appendConfirmedAddrs(currReachable[:0], currUnreachable[:0]) + if areAddrsDifferent(prevReachable, currReachable) || areAddrsDifferent(prevUnreachable, currUnreachable) { + r.notify() } - prevReachable, prevUnreachable = r.reachableAddrs, r.unreachableAddrs - if backoffInterval > resetInterval { - resetInterval = backoffInterval + prevReachable = append(prevReachable[:0], currReachable...) + prevUnreachable = append(prevUnreachable[:0], currUnreachable...) 
+ if !nextProbeTime.IsZero() { + probeTimer.Reset(nextProbeTime.Sub(r.clock.Now())) } - timer.Reset(resetInterval) } }() return nil @@ -191,11 +201,12 @@ func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) { const ( backoffStartInterval = 5 * time.Second - maxBackoffInterval = 2 * defaultResetInterval + // maxBackoffInterval should be shorter or equal to defaultResetInterval as we probe every reset interval at least. + maxBackoffInterval = defaultResetInterval ) func newBackoffInterval(current time.Duration) time.Duration { - if current == 0 { + if current <= 0 { return backoffStartInterval } current *= 2 @@ -205,6 +216,7 @@ func newBackoffInterval(current time.Duration) time.Duration { return current } +// reachabilityTask is a task to refresh reachability. type reachabilityTask struct { Cancel context.CancelFunc RespCh chan bool @@ -410,20 +422,17 @@ func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiadd func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) { m.mx.Lock() defer m.mx.Unlock() + + slices.SortStableFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + for _, addr := range addrs { if _, ok := m.statuses[string(addr.Bytes())]; !ok { m.statuses[string(addr.Bytes())] = &addrStatus{Addr: addr} } } for k, s := range m.statuses { - found := false - for _, a := range addrs { - if a.Equal(s.Addr) { - found = true - break - } - } - if !found { + _, ok := slices.BinarySearchFunc(addrs, s.Addr, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + if !ok { delete(m.statuses, k) } } From 9e35bc1fceda359bb25f109544e4824765b73540 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 18 Apr 2025 13:17:30 +0530 Subject: [PATCH 07/15] fix bug with timer --- p2p/host/basic/addrs_reachability_tracker.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index f50fa0e8a2..7d39dfb83e 100644 --- 
a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -112,10 +112,10 @@ func (r *addrsReachabilityTracker) background() error { for { select { case <-probeTicker.C: - if task.RespCh == nil { + // don't start a probe if we have a scheduled probe + if task.RespCh == nil && nextProbeTime.IsZero() { task = r.refreshReachability() } - nextProbeTime = time.Time{} case <-probeTimer.C: if task.RespCh == nil { task = r.refreshReachability() From 386c7125e3b22b0237453a819115a7b7e4c1e376 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 20 Apr 2025 20:16:41 +0530 Subject: [PATCH 08/15] update tests --- p2p/host/basic/addrs_manager_test.go | 10 +- p2p/host/basic/addrs_reachability_tracker.go | 232 ++++---- .../basic/addrs_reachability_tracker_test.go | 529 ++++++++++-------- p2p/protocol/autonatv2/autonat_test.go | 6 + p2p/protocol/autonatv2/client.go | 21 +- 5 files changed, 436 insertions(+), 362 deletions(-) diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index 260c3de07d..69d264f8f6 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -11,7 +11,6 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" - "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/assert" @@ -436,14 +435,11 @@ func TestAddrsManager(t *testing.T) { } func TestAddrsManagerReachabilityEvent(t *testing.T) { - // Setup test addresses publicQUIC, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1") publicTCP, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") - // Create a new event bus bus := eventbus.NewBus() - // Subscribe to EvtHostReachableAddrsChanged events sub, err := bus.Subscribe(new(event.EvtHostReachableAddrsChanged)) require.NoError(t, err) defer 
sub.Close() @@ -455,9 +451,9 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) { AutoNATClient: mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { if reqs[0].Addr.Equal(publicQUIC) { - return autonatv2.Result{Addr: reqs[0].Addr, DialStatus: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil } else if reqs[0].Addr.Equal(publicTCP) { - return autonatv2.Result{Addr: reqs[0].Addr, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPrivate}, nil } t.Errorf("received invalid request for addr: %+v", reqs[0]) return autonatv2.Result{}, errors.New("invalid") @@ -468,11 +464,9 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) { reachableAddrs := []ma.Multiaddr{publicQUIC} unreachableAddrs := []ma.Multiaddr{publicTCP} - // No new event should be received select { case e := <-sub.Out(): evt := e.(event.EvtHostReachableAddrsChanged) - // Verify the event contains the expected addresses require.ElementsMatch(t, reachableAddrs, evt.Reachable) require.ElementsMatch(t, unreachableAddrs, evt.Unreachable) require.ElementsMatch(t, reachableAddrs, am.ReachableAddrs()) diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index 7d39dfb83e..4023108c3a 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -29,6 +29,8 @@ const ( maxTrackedAddrs = 50 // defaultMaxConcurrency is the default number of concurrent workers for reachability checks defaultMaxConcurrency = 5 + // newAddrsProbeDelay is the delay before probing new addr's reachability. 
+ newAddrsProbeDelay = 1 * time.Second ) type addrsReachabilityTracker struct { @@ -64,8 +66,8 @@ func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh ch cancel: cancel, cli: client, reachabilityUpdateCh: reachabilityUpdateCh, - addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow), - newAddrsProbeDelay: 1 * time.Second, + addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow, defaultResetInterval, 10), + newAddrsProbeDelay: newAddrsProbeDelay, maxConcurrency: defaultMaxConcurrency, newAddrs: make(chan []ma.Multiaddr, 1), clock: cl, @@ -76,6 +78,12 @@ func (r *addrsReachabilityTracker) UpdateAddrs(addrs []ma.Multiaddr) { r.newAddrs <- slices.Clone(addrs) } +func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + r.mx.Lock() + defer r.mx.Unlock() + return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs) +} + func (r *addrsReachabilityTracker) Start() error { r.wg.Add(1) err := r.background() @@ -91,7 +99,11 @@ func (r *addrsReachabilityTracker) Close() error { return nil } -const defaultResetInterval = 5 * time.Minute +const ( + defaultResetInterval = 5 * time.Minute + maxBackoffInterval = 5 * time.Minute + backoffStartInterval = 5 * time.Second +) func (r *addrsReachabilityTracker) background() error { go func() { @@ -166,10 +178,15 @@ func (r *addrsReachabilityTracker) background() error { return nil } -func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) { - r.mx.Lock() - defer r.mx.Unlock() - return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs) +func newBackoffInterval(current time.Duration) time.Duration { + if current <= 0 { + return backoffStartInterval + } + current *= 2 + if current > maxBackoffInterval { + return maxBackoffInterval + } + return current } func (r *addrsReachabilityTracker) appendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, 
unreachableAddrs []ma.Multiaddr) { @@ -199,23 +216,6 @@ func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) { r.addrTracker.UpdateAddrs(addrs) } -const ( - backoffStartInterval = 5 * time.Second - // maxBackoffInterval should be shorter or equal to defaultResetInterval as we probe every reset interval at least. - maxBackoffInterval = defaultResetInterval -) - -func newBackoffInterval(current time.Duration) time.Duration { - if current <= 0 { - return backoffStartInterval - } - current *= 2 - if current > maxBackoffInterval { - return maxBackoffInterval - } - return current -} - // reachabilityTask is a task to refresh reachability. type reachabilityTask struct { Cancel context.CancelFunc @@ -241,7 +241,7 @@ func (r *addrsReachabilityTracker) refreshReachability() reachabilityTask { var errTooManyConsecutiveFailures = errors.New("too many consecutive failures") // errCountingClient counts errors from autonatv2Client and wraps the errors in response with a -// errTooManyConsecutiveFailures in case of many consecutive failures +// errTooManyConsecutiveFailures in case of persistent failures from autonatv2 module. type errCountingClient struct { autonatv2Client MaxConsecutiveErrors int @@ -347,10 +347,6 @@ func isErrorPersistent(err error) bool { } const ( - // maxProbeResultTTL is the maximum time to keep probe results for an address - maxProbeResultTTL = 3 * time.Hour - // maxProbeInterval is the maximum interval between probes for an address - maxProbeInterval = 1 * time.Hour // addrRefusedProbeInterval is the interval to probe addresses that have been refused // these are generally addresses with newer transports for which we don't have many peers // capable of dialing the transport @@ -371,6 +367,10 @@ const ( // and then a success(...S S S S F S). The confidence in the targetConfidence window will be equal to // targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval. 
maxRecentProbeResultWindow = targetConfidence + 2 + // maxProbeInterval is the maximum interval between probes for an address + maxProbeInterval = 1 * time.Hour + // maxProbeResultTTL is the maximum time to keep probe results for an address + maxProbeResultTTL = maxRecentProbeResultWindow * maxProbeInterval ) // probeManager tracks reachability for a set of addresses. This struct decides the priority order of @@ -381,10 +381,10 @@ const ( // Probes returned by `GetProbe` should be marked as in progress using `MarkProbeInProgress` // before being executed. type probeManager struct { - now func() time.Time - recentProbeResultWindow int - ProbeInterval time.Duration - MaxProbesPerAddrsPerInterval int + now func() time.Time + recentProbeResultWindow int + ProbeInterval time.Duration + MaxDialsPerAddrsPerInterval int mx sync.Mutex inProgressProbes map[string]int // addr -> count @@ -393,12 +393,14 @@ type probeManager struct { addrs []ma.Multiaddr } -func newProbeManager(now func() time.Time, recentProbeResultWindow int) *probeManager { +func newProbeManager(now func() time.Time, recentProbeResultWindow int, probeInterval time.Duration, maxProbesPerAddrsPerInterval int) *probeManager { return &probeManager{ - statuses: make(map[string]*addrStatus), - inProgressProbes: make(map[string]int), - now: now, - recentProbeResultWindow: recentProbeResultWindow, + statuses: make(map[string]*addrStatus), + inProgressProbes: make(map[string]int), + now: now, + recentProbeResultWindow: recentProbeResultWindow, + ProbeInterval: probeInterval, + MaxDialsPerAddrsPerInterval: maxProbesPerAddrsPerInterval, } } @@ -408,7 +410,9 @@ func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiadd defer m.mx.Unlock() for _, a := range m.addrs { - switch m.statuses[string(a.Bytes())].outcomes.Reachability() { + s := m.statuses[string(a.Bytes())] + s.outcomes.RemoveBefore(m.now().Add(-maxProbeResultTTL)) // cleanup stale results + switch s.outcomes.Reachability() { case 
network.ReachabilityPublic: reachable = append(reachable, a) case network.ReachabilityPrivate: @@ -423,11 +427,12 @@ func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) { m.mx.Lock() defer m.mx.Unlock() - slices.SortStableFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) for _, addr := range addrs { - if _, ok := m.statuses[string(addr.Bytes())]; !ok { - m.statuses[string(addr.Bytes())] = &addrStatus{Addr: addr} + k := string(addr.Bytes()) + if _, ok := m.statuses[k]; !ok { + m.statuses[k] = &addrStatus{Addr: addr, outcomes: addrOutcomes{}} } } for k, s := range m.statuses { @@ -446,23 +451,36 @@ func (m *probeManager) GetProbe() []autonatv2.Request { m.mx.Lock() defer m.mx.Unlock() - reqs := make([]autonatv2.Request, 0, maxAddrsPerRequest) now := m.now() - for _, a := range m.addrs { - akey := string(a.Bytes()) - pc := m.probeCount(m.statuses[akey], now) + for i, a := range m.addrs { + ab := a.Bytes() + pc := m.requiredProbes(m.statuses[string(ab)], now) if pc == 0 { continue } - if len(reqs) == 0 && m.inProgressProbes[akey] >= pc { + if m.inProgressProbes[string(ab)] >= pc { continue } + reqs := make([]autonatv2.Request, 0, maxAddrsPerRequest) reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true}) - if len(reqs) >= maxAddrsPerRequest { - break + // We have the first(primary) address. Append other addresses, ignoring inprogress probes + // on secondary addresses. The expectation is that the primary address will + // be dialed. + for j := 1; j < len(m.addrs); j++ { + k := (i + j) % len(m.addrs) + ab := m.addrs[k].Bytes() + pc := m.requiredProbes(m.statuses[string(ab)], now) + if pc == 0 { + continue + } + reqs = append(reqs, autonatv2.Request{Addr: m.addrs[k], SendDialData: true}) + if len(reqs) >= maxAddrsPerRequest { + break + } } + return reqs } - return reqs + return nil } // MarkProbeInProgress should be called when a probe is started. 
@@ -499,10 +517,10 @@ func (m *probeManager) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Res // decrement in-progress count for the first address primaryAddrKey := string(reqs[0].Addr.Bytes()) m.inProgressProbes[primaryAddrKey]-- - m.inProgressProbesTotal-- if m.inProgressProbes[primaryAddrKey] <= 0 { delete(m.inProgressProbes, primaryAddrKey) } + m.inProgressProbesTotal-- // nothing to do if the request errored. if err != nil { @@ -510,29 +528,43 @@ func (m *probeManager) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Res } expireBefore := now.Add(-maxProbeInterval) - // request failed if res.AllAddrsRefused { if s, ok := m.statuses[primaryAddrKey]; ok { - s.lastRefusalTime = now - s.consecutiveRefusals++ + m.addRefusal(s, now, expireBefore) } return } - // mark only the primary status as refused - if s, ok := m.statuses[primaryAddrKey]; ok { - s.lastRefusalTime = now - s.consecutiveRefusals++ + // Consider only primary address as refused. This increases the number of + // probes are refused, but refused probes are cheap as no dial is + // made by the server. 
+ dialAddrKey := string(res.Addr.Bytes()) + if dialAddrKey != primaryAddrKey { + if s, ok := m.statuses[primaryAddrKey]; ok { + m.addRefusal(s, now, expireBefore) + } } // record the result for the probed address - if s, ok := m.statuses[string(res.Addr.Bytes())]; ok { - s.outcomes.AddDialOutcomeAndExpire(now, res.Reachability, m.recentProbeResultWindow, expireBefore) - s.probeTimes = append(s.probeTimes, now) + if s, ok := m.statuses[dialAddrKey]; ok { + m.addDial(s, now, res.Reachability, expireBefore) } } -func (m *probeManager) probeCount(s *addrStatus, now time.Time) int { +func (*probeManager) addRefusal(s *addrStatus, now time.Time, expireBefore time.Time) { + s.lastRefusalTime = now + s.consecutiveRefusals++ +} + +func (m *probeManager) addDial(s *addrStatus, now time.Time, rch network.Reachability, expireBefore time.Time) { + s.lastRefusalTime = time.Time{} + s.consecutiveRefusals = 0 + s.dialTimes = append(s.dialTimes, now) + s.outcomes.AddOutcome(now, rch, m.recentProbeResultWindow) + s.outcomes.RemoveBefore(expireBefore) +} + +func (m *probeManager) requiredProbes(s *addrStatus, now time.Time) int { if s.consecutiveRefusals >= maxConsecutiveRefusals { if now.Sub(s.lastRefusalTime) < addrRefusedProbeInterval { return 0 @@ -543,16 +575,16 @@ func (m *probeManager) probeCount(s *addrStatus, now time.Time) int { } // Don't probe if we have probed too many times recently - if m.recentProbesCount(s, now) >= m.MaxProbesPerAddrsPerInterval { + if m.recentDialCount(s, now) >= m.MaxDialsPerAddrsPerInterval { return 0 } - return s.outcomes.ProbeCount(now) + return s.outcomes.RequiredProbes(now) } -func (m *probeManager) recentProbesCount(s *addrStatus, now time.Time) int { +func (m *probeManager) recentDialCount(s *addrStatus, now time.Time) int { cnt := 0 - for _, t := range slices.Backward(s.probeTimes) { + for _, t := range slices.Backward(s.dialTimes) { if now.Sub(t) > m.ProbeInterval { break } @@ -570,8 +602,8 @@ type addrStatus struct { Addr ma.Multiaddr 
lastRefusalTime time.Time consecutiveRefusals int - probeTimes []time.Time - outcomes *addrOutcomes // TODO: no pointer? + dialTimes []time.Time + outcomes addrOutcomes } type addrOutcomes struct { @@ -579,33 +611,12 @@ type addrOutcomes struct { } func (o *addrOutcomes) Reachability() network.Reachability { - successes, failures := o.outcomeCounts() - return o.computeReachability(successes, failures) + rch, _, _ := o.reachabilityAndCounts() + return rch } -func (*addrOutcomes) computeReachability(success, failures int) network.Reachability { - if success-failures >= minConfidence { - return network.ReachabilityPublic - } - if failures-success >= minConfidence { - return network.ReachabilityPrivate - } - return network.ReachabilityUnknown -} - -func (o *addrOutcomes) numProbesInInterval(now time.Time, probeInterval time.Duration) int { - cnt := 0 - for _, v := range slices.Backward(o.outcomes) { - if now.Sub(v.At) > probeInterval { - break - } - cnt++ - } - return cnt -} - -func (o *addrOutcomes) ProbeCount(now time.Time) int { - successes, failures := o.outcomeCounts() +func (o *addrOutcomes) RequiredProbes(now time.Time) int { + reachability, successes, failures := o.reachabilityAndCounts() confidence := successes - failures if confidence < 0 { confidence = -confidence @@ -614,10 +625,9 @@ func (o *addrOutcomes) ProbeCount(now time.Time) int { if cnt > 0 { return cnt } - // we have enough confirmations, see if we should still retest + // we have enough confirmations; check if we should refresh - // There are no confirmations. This should never happen. In case there are no confirmations, - // the confidence logic above should require a few probes. + // Should never happen. The confidence logic above should require a few probes. if len(o.outcomes) == 0 { return 0 } @@ -627,7 +637,7 @@ func (o *addrOutcomes) ProbeCount(now time.Time) int { return 1 } // if the last probe result was different from reachability, probe again. 
- switch o.computeReachability(successes, failures) { + switch reachability { case network.ReachabilityPublic: if !lastOutcome.Success { return 1 @@ -637,13 +647,13 @@ func (o *addrOutcomes) ProbeCount(now time.Time) int { return 1 } default: - // this should never happen. no reachability => confidence is low + // this should never happen return 1 } return 0 } -func (o *addrOutcomes) AddDialOutcomeAndExpire(at time.Time, rch network.Reachability, windowSize int, expireBefore time.Time) { +func (o *addrOutcomes) AddOutcome(at time.Time, rch network.Reachability, windowSize int) { success := false switch rch { case network.ReachabilityPublic: @@ -657,22 +667,20 @@ func (o *addrOutcomes) AddDialOutcomeAndExpire(at time.Time, rch network.Reachab if len(o.outcomes) > windowSize { o.outcomes = slices.Delete(o.outcomes, 0, len(o.outcomes)-windowSize) } - o.removeBefore(expireBefore) } -// removeBefore removes outcomes before t -func (o *addrOutcomes) removeBefore(t time.Time) { - st := 0 - var ot dialOutcome - for st, ot = range o.outcomes { - if ot.At.After(t) { +// RemoveBefore removes outcomes before t +func (o *addrOutcomes) RemoveBefore(t time.Time) { + var end = 0 + for ; end < len(o.outcomes); end++ { + if !o.outcomes[end].At.Before(t) { break } } - o.outcomes = slices.Delete(o.outcomes, 0, st) + o.outcomes = slices.Delete(o.outcomes, 0, end) } -func (o *addrOutcomes) outcomeCounts() (successes, failures int) { +func (o *addrOutcomes) reachabilityAndCounts() (rch network.Reachability, successes int, failures int) { for _, r := range o.outcomes { if r.Success { successes++ @@ -680,5 +688,11 @@ func (o *addrOutcomes) outcomeCounts() (successes, failures int) { failures++ } } - return successes, failures + if successes-failures >= minConfidence { + return network.ReachabilityPublic, successes, failures + } + if failures-successes >= minConfidence { + return network.ReachabilityPrivate, successes, failures + } + return network.ReachabilityUnknown, successes, failures } 
diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index dcaf033b62..7bc6576bc7 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -12,245 +12,183 @@ import ( "github.com/benbjohnson/clock" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" - "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestAddrTrackerGetProbe(t *testing.T) { +func TestProbeManager(t *testing.T) { pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1") + pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1") cl := clock.NewMock() - t.Run("inprogress probes", func(t *testing.T) { - tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) + nextProbe := func(pm *probeManager) []autonatv2.Request { + reqs := pm.GetProbe() + if len(reqs) != 0 { + pm.MarkProbeInProgress(reqs) + } + return reqs + } + + makeNewProbeManager := func(addrs []ma.Multiaddr) *probeManager { + pm := newProbeManager(cl.Now, maxRecentProbeResultWindow, 10*time.Minute, 10) + pm.UpdateAddrs(addrs) + return pm + } - tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) - reqs1 := tr.GetProbe() - reqs2 := tr.GetProbe() + t.Run("addrs updates", func(t *testing.T) { + pm := newProbeManager(cl.Now, maxRecentProbeResultWindow, 10*time.Minute, 10) + pm.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) + for { + reqs := nextProbe(pm) + if len(reqs) == 0 { + break + } + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } + reachable, _ := pm.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1, pub2}) + pm.UpdateAddrs([]ma.Multiaddr{pub3}) + + reachable, _ = pm.AppendConfirmedAddrs(nil, nil) + require.Empty(t, reachable) + 
require.Len(t, pm.statuses, 1) + }) + + t.Run("inprogress", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + reqs1 := pm.GetProbe() + reqs2 := pm.GetProbe() require.Equal(t, reqs1, reqs2) - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() - require.NotEmpty(t, reqs) - tr.MarkProbeInProgress(reqs) + for i := 0; i < targetConfidence; i++ { + reqs := nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) } - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() - require.NotEmpty(t, reqs) - tr.MarkProbeInProgress(reqs) - require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) - } - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() - require.Empty(t, reqs) + for i := 0; i < targetConfidence; i++ { + reqs := nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}}) } + reqs := pm.GetProbe() + require.Empty(t, reqs) }) - t.Run("probe refusals", func(t *testing.T) { - tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) - tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) + t.Run("refusals", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) var probes [][]autonatv2.Request - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() + for i := 0; i < targetConfidence; i++ { + reqs := nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) - tr.MarkProbeInProgress(reqs) probes = append(probes, reqs) } - // first one rejected second one successful + // first one refused second one successful for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_OK}, nil) + pm.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Idx: 1, Reachability: network.ReachabilityPublic}, nil) } // the second address is validated! 
probes = nil - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() + for i := 0; i < targetConfidence; i++ { + reqs := nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) - tr.MarkProbeInProgress(reqs) probes = append(probes, reqs) } - reqs := tr.GetProbe() + reqs := pm.GetProbe() require.Empty(t, reqs) for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{}, autonatv2.ErrDialRefused) + pm.CompleteProbe(probes[i], autonatv2.Result{AllAddrsRefused: true}, nil) } - // all requests refused - reqs = tr.GetProbe() + // all requests refused; no more probes for too many refusals + reqs = pm.GetProbe() require.Empty(t, reqs) cl.Add(10*time.Minute + 5*time.Second) - reqs = tr.GetProbe() + reqs = pm.GetProbe() require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) }) - t.Run("probe successes", func(t *testing.T) { - tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) - tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) - var probes [][]autonatv2.Request - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() - require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) - tr.MarkProbeInProgress(reqs) - probes = append(probes, reqs) - } - // first one rejected second one successful - for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil) - } - // the second address is validated! 
- probes = nil - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() - require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) - tr.MarkProbeInProgress(reqs) - probes = append(probes, reqs) - } - reqs := tr.GetProbe() - require.Empty(t, reqs) - for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_OK}, nil) + t.Run("successes", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + for j := 0; j < 2; j++ { + for i := 0; i < targetConfidence; i++ { + reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } } - // all statueses probed - reqs = tr.GetProbe() + // all addrs confirmed + reqs := pm.GetProbe() require.Empty(t, reqs) cl.Add(1*time.Hour + 5*time.Second) - reqs = tr.GetProbe() + reqs = nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) - tr.MarkProbeInProgress(reqs) - reqs = tr.GetProbe() - require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) + reqs = nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}}) }) - t.Run("reachabilityUpdate", func(t *testing.T) { - tr := newProbeManager(cl.Now, maxRecentProbeResultWindow) - tr.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) - var probes [][]autonatv2.Request - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() - require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) - tr.MarkProbeInProgress(reqs) - probes = append(probes, reqs) + t.Run("throttling on indeterminate reachability", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + reachability := network.ReachabilityPublic + nextReachability := func() network.Reachability { + if reachability == network.ReachabilityPublic { + 
reachability = network.ReachabilityPrivate + } else { + reachability = network.ReachabilityPublic + } + return reachability } - for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil) + for range 2 * 10 { + reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil) } - probes = nil - for i := 0; i < 3; i++ { - reqs := tr.GetProbe() - require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}}) - tr.MarkProbeInProgress(reqs) - probes = append(probes, reqs) + reqs := pm.GetProbe() + require.Empty(t, reqs) + + cl.Add(10*time.Minute + 5*time.Second) + reqs = pm.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + for range 2 * 10 { + reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil) } - for i := 0; i < len(probes); i++ { - tr.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil) + reqs = pm.GetProbe() + require.Empty(t, reqs) + + }) + + t.Run("reachabilityUpdate", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + for range 2 * targetConfidence { + reqs := nextProbe(pm) + if reqs[0].Addr.Equal(pub1) { + pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } else { + pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub2, Idx: 0, Reachability: network.ReachabilityPrivate}, nil) + } } - reachable, unreachable := tr.AppendConfirmedAddrs(nil, nil) + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) require.Equal(t, reachable, []ma.Multiaddr{pub1}) require.Equal(t, unreachable, []ma.Multiaddr{pub2}) + }) + t.Run("expiry", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1}) + for range 2 * targetConfidence { 
+ reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } + + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) - // should expire addrs after 3 hours - cl.Add(3*time.Hour + 1*time.Second) - reachable, unreachable = tr.AppendConfirmedAddrs(nil, nil) + cl.Add(maxProbeResultTTL + 1*time.Second) + reachable, unreachable = pm.AppendConfirmedAddrs(nil, nil) require.Empty(t, reachable) require.Empty(t, unreachable) }) } -func TestAddrStatus(t *testing.T) { - now := time.Now() - probeResultWindow := maxRecentProbeResultWindow - - type input struct { - At time.Time - Success, Refused bool - } - type testCase struct { - inputs []input - probeCount int - reachability network.Reachability - } - tests := []testCase{ - { - inputs: []input{ - {At: now, Success: true}, - }, - probeCount: 2, - reachability: network.ReachabilityUnknown, - }, - { - inputs: []input{ - {At: now, Success: false}, - {At: now, Success: true}, - {At: now, Success: true}, - {At: now, Success: true}, - }, - probeCount: 1, - reachability: network.ReachabilityPublic, - }, - { - inputs: []input{ - {At: now, Success: true}, - {At: now, Success: false}, - {At: now, Success: false}, - {At: now, Success: false}, - }, - probeCount: 1, - reachability: network.ReachabilityPrivate, - }, - { - inputs: []input{ - {At: now, Success: false}, - {At: now, Success: false}, - {At: now, Success: false}, - {At: now, Success: false}, - {At: now, Success: false}, - {At: now, Success: false}, - {At: now, Success: true}, - {At: now, Success: true}, - {At: now, Success: true}, - {At: now, Success: true}, - }, - probeCount: 0, - reachability: network.ReachabilityPublic, - }, - } - for i, tt := range tests { - t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { - s := &addrStatus{Addr: ma.StringCast("/ip4/1.1.1.1/tcp/1")} - for _, inp := range tt.inputs { - if 
inp.Refused { - s.AddRefusal(now) - } else { - s.AddResult(now, inp.Success) - } - s.Trim(probeResultWindow) - } - require.Equal(t, tt.reachability, s.Reachability()) - require.Equal(t, tt.probeCount, s.ProbeCount(now)) - }) - } -} - -func TestAddrStatusRefused(t *testing.T) { - s := &addrStatus{Addr: ma.StringCast("/ip4/1.1.1.1/tcp/1")} - now := time.Now() - for i := 0; i < maxConsecutiveRefusals-1; i++ { - s.AddRefusal(now) - } - require.Equal(t, s.ProbeCount(now), 3) - s.AddRefusal(now) - require.Equal(t, s.ProbeCount(now), 0) - require.Equal(t, s.ProbeCount(now.Add(addrRefusedProbeInterval+(1*time.Nanosecond))), 1) // +1 to push it over the threshold - - s.AddResult(now, true) - require.Equal(t, s.ProbeCount(now), 2) - require.Equal(t, s.consecutiveRefusals.Count, 0) -} - type mockAutoNATClient struct { F func(context.Context, []autonatv2.Request) (autonatv2.Result, error) } @@ -261,11 +199,11 @@ func (m mockAutoNATClient) GetReachability(ctx context.Context, reqs []autonatv2 var _ autonatv2Client = mockAutoNATClient{} -func TestAddrReachabilityTracker(t *testing.T) { - pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") - pub2, _ := ma.NewMultiaddr("/ip4/1.1.1.2/tcp/1") - pub3, _ := ma.NewMultiaddr("/ip4/1.1.1.3/tcp/1") - pri, _ := ma.NewMultiaddr("/ip4/192.168.1.1/tcp/1") +func TestAddrsReachabilityTracker(t *testing.T) { + pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1") + pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1") + pri := ma.StringCast("/ip4/192.168.1.1/tcp/1") newTracker := func(cli mockAutoNATClient, cl clock.Clock) *addrsReachabilityTracker { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -280,7 +218,7 @@ func TestAddrReachabilityTracker(t *testing.T) { reachabilityUpdateCh: make(chan struct{}, 1), maxConcurrency: 3, newAddrsProbeDelay: 0 * time.Second, - addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow), + addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow, 
defaultResetInterval, 10), clock: cl, } err := tr.Start() @@ -293,20 +231,21 @@ func TestAddrReachabilityTracker(t *testing.T) { } t.Run("simple", func(t *testing.T) { + // pub1 reachable, pub2 unreachable, pub3 ignored mockClient := mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { - for _, req := range reqs { + for i, req := range reqs { if req.Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil } else if req.Addr.Equal(pub2) { - return autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil + return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil } } - return autonatv2.Result{}, autonatv2.ErrDialRefused + return autonatv2.Result{}, autonatv2.ErrNoPeers }, } tr := newTracker(mockClient, nil) - tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pri}) + tr.UpdateAddrs([]ma.Multiaddr{pub2, pub1, pri}) select { case <-tr.reachabilityUpdateCh: case <-time.After(2 * time.Second): @@ -314,27 +253,29 @@ func TestAddrReachabilityTracker(t *testing.T) { } reachable, unreachable := tr.ConfirmedAddrs() require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1) - require.Empty(t, unreachable) + require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2) - tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pub2, pri}) + tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pri}) select { case <-tr.reachabilityUpdateCh: - case <-time.After(1 * time.Second): - t.Fatal("unexpected call") + case <-time.After(2 * time.Second): + t.Fatal("expected reachability update") } reachable, unreachable = tr.ConfirmedAddrs() require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1) - require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2) + require.Empty(t, unreachable) }) t.Run("backoff", func(t *testing.T) 
{ notify := make(chan struct{}, 1) - drainNotify := func() { + drainNotify := func() bool { + found := false for { select { case <-notify: + found = true default: - return + return found } } } @@ -350,9 +291,9 @@ func TestAddrReachabilityTracker(t *testing.T) { return autonatv2.Result{}, autonatv2.ErrNoPeers } if reqs[0].Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil } - return autonatv2.Result{}, autonatv2.ErrDialRefused + return autonatv2.Result{AllAddrsRefused: true}, nil }, } @@ -364,14 +305,8 @@ func TestAddrReachabilityTracker(t *testing.T) { // need to update clock after the background goroutine processes the new addrs time.Sleep(100 * time.Millisecond) cl.Add(1) - select { - case <-tr.reachabilityUpdateCh: - reachable, unreachable := tr.ConfirmedAddrs() - require.Empty(t, reachable) - require.Empty(t, unreachable) - case <-time.After(1 * time.Second): - t.Fatal("unexpected call") - } + time.Sleep(100 * time.Millisecond) + require.True(t, drainNotify()) // check that we did receive probes backoffInterval := backoffStartInterval for i := 0; i < 4; i++ { @@ -385,9 +320,9 @@ func TestAddrReachabilityTracker(t *testing.T) { cl.Add(backoffInterval/2 + 1) // +1 to push it slightly over the backoff interval backoffInterval *= 2 select { - case <-tr.reachabilityUpdateCh: + case <-notify: case <-time.After(1 * time.Second): - t.Fatal("unexpected call") + t.Fatal("expected probe") } reachable, unreachable := tr.ConfirmedAddrs() require.Empty(t, reachable) @@ -399,7 +334,7 @@ func TestAddrReachabilityTracker(t *testing.T) { select { case <-tr.reachabilityUpdateCh: case <-time.After(1 * time.Second): - t.Fatal("unexpected call") + t.Fatal("unexpected reachability update") } reachable, unreachable := tr.ConfirmedAddrs() require.Equal(t, reachable, []ma.Multiaddr{pub1}) @@ -413,16 +348,16 @@ func TestAddrReachabilityTracker(t 
*testing.T) { mockClient := mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { select { - case <-ctx.Done(): - return autonatv2.Result{}, ctx.Err() case called <- struct{}{}: notify <- struct{}{} + return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil + default: + return autonatv2.Result{AllAddrsRefused: true}, nil } - return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil }, } - tr := newTracker(mockClient, clock.New()) + tr := newTracker(mockClient, nil) tr.UpdateAddrs([]ma.Multiaddr{pub1}) for i := 0; i < minConfidence; i++ { select { @@ -433,16 +368,81 @@ func TestAddrReachabilityTracker(t *testing.T) { } select { case <-tr.reachabilityUpdateCh: - t.Fatal("didn't expect reachability update") + reachable, unreachable := tr.ConfirmedAddrs() + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) case <-time.After(1 * time.Second): + t.Fatal("expected reachability update") } - tr.UpdateAddrs([]ma.Multiaddr{pub1}) + tr.UpdateAddrs([]ma.Multiaddr{pub1}) // same addrs shouldn't get update + select { + case <-tr.reachabilityUpdateCh: + t.Fatal("didn't expect reachability update") + case <-time.After(100 * time.Millisecond): + } + tr.UpdateAddrs([]ma.Multiaddr{pub2}) select { case <-tr.reachabilityUpdateCh: + reachable, unreachable := tr.ConfirmedAddrs() + require.Empty(t, reachable) + require.Empty(t, unreachable) case <-time.After(1 * time.Second): t.Fatal("expected reachability update") } }) + + t.Run("refresh after reset interval", func(t *testing.T) { + notify := make(chan struct{}, 1) + drainNotify := func() bool { + found := false + for { + select { + case <-notify: + found = true + default: + return found + } + } + } + + mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + select { + case notify <- struct{}{}: + default: + } + if reqs[0].Addr.Equal(pub1) { + return 
autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil + } + return autonatv2.Result{AllAddrsRefused: true}, nil + }, + } + + cl := clock.NewMock() + tr := newTracker(mockClient, cl) + + // update addrs and wait for initial checks + tr.UpdateAddrs([]ma.Multiaddr{pub1}) + // need to update clock after the background goroutine processes the new addrs + time.Sleep(100 * time.Millisecond) + cl.Add(1) + time.Sleep(100 * time.Millisecond) + require.True(t, drainNotify()) // check that we did receive probes + + cl.Add(maxProbeInterval / 2) + select { + case <-notify: + t.Fatal("unexpected call") + case <-time.After(50 * time.Millisecond): + } + + cl.Add(maxProbeInterval/2 + defaultResetInterval) // defaultResetInterval for the next probe time + select { + case <-notify: + case <-time.After(1 * time.Second): + t.Fatal("expected probe") + } + }) } func TestRunProbes(t *testing.T) { @@ -457,7 +457,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) require.True(t, result) @@ -472,7 +472,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) @@ -490,7 +490,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) var wg sync.WaitGroup wg.Add(1) @@ -523,17 +523,17 @@ func TestRunProbes(t *testing.T) { 
t.Run("handles refusals", func(t *testing.T) { pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { - for _, req := range reqs { + for i, req := range reqs { if req.Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil } } - return autonatv2.Result{}, autonatv2.ErrDialRefused + return autonatv2.Result{AllAddrsRefused: true}, nil }, } @@ -547,20 +547,20 @@ func TestRunProbes(t *testing.T) { }) t.Run("handles completions", func(t *testing.T) { - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow) + addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { - for _, req := range reqs { + for i, req := range reqs { if req.Addr.Equal(pub1) { - return autonatv2.Result{Addr: pub1, DialStatus: pb.DialStatus_OK}, nil + return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil } if req.Addr.Equal(pub2) { - return autonatv2.Result{Addr: pub2, DialStatus: pb.DialStatus_E_DIAL_ERROR}, nil + return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil } } - return autonatv2.Result{}, autonatv2.ErrDialRefused + return autonatv2.Result{AllAddrsRefused: true}, nil }, } @@ -574,9 +574,72 @@ func TestRunProbes(t *testing.T) { }) } +func TestDialOutcome(t *testing.T) { + cases := []struct { + inputs string + wantRequiredProbes int + wantReachability 
network.Reachability + }{ + { + inputs: "SSSSSSSSSSS", + wantRequiredProbes: 0, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SSSSSSSSSSF", + wantRequiredProbes: 1, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SFSFSFSFSSSS", + wantRequiredProbes: 0, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SSSSSSSSSFSF", + wantRequiredProbes: 2, + wantReachability: network.ReachabilityUnknown, + }, + { + inputs: "S", + wantRequiredProbes: 2, + wantReachability: network.ReachabilityUnknown, + }, + { + inputs: "FF", + wantRequiredProbes: 1, + wantReachability: network.ReachabilityPrivate, + }, + } + for _, c := range cases { + t.Run(c.inputs, func(t *testing.T) { + now := time.Time{}.Add(1 * time.Second) + ao := addrOutcomes{} + for _, r := range c.inputs { + if r == 'S' { + ao.AddOutcome(now, network.ReachabilityPublic, 5) + } else { + ao.AddOutcome(now, network.ReachabilityPrivate, 5) + } + now = now.Add(1 * time.Second) + } + require.Equal(t, ao.RequiredProbes(now), c.wantRequiredProbes) + require.Equal(t, ao.Reachability(), c.wantReachability) + if c.wantRequiredProbes == 0 { + now = now.Add(1*time.Hour + 10*time.Microsecond) + require.Equal(t, ao.RequiredProbes(now), 1) + } + + now = now.Add(1 * time.Second) + ao.RemoveBefore(now) + require.Len(t, ao.outcomes, 0) + }) + } +} + func BenchmarkAddrTracker(b *testing.B) { cl := clock.NewMock() - t := newProbeManager(cl.Now, maxRecentProbeResultWindow) + t := newProbeManager(cl.Now, maxRecentProbeResultWindow, 10*time.Minute, 10) var addrs []ma.Multiaddr for i := 0; i < 20; i++ { @@ -592,6 +655,6 @@ func BenchmarkAddrTracker(b *testing.B) { pp = p } t.MarkProbeInProgress(pp) - t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, DialStatus: pb.DialStatus_OK}, nil) + t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil) } } diff --git a/p2p/protocol/autonatv2/autonat_test.go 
b/p2p/protocol/autonatv2/autonat_test.go index 791c0f9f83..578ddccb1e 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -592,6 +592,12 @@ func TestAreAddrsConsistency(t *testing.T) { dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"), success: false, }, + { + name: "dns6", + localAddr: ma.StringCast("/dns6/lib.p2p/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/"), + success: false, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index f15cb9ecbe..fea217b1e3 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -302,7 +302,6 @@ func (ac *client) handleDialBack(s network.Stream) { } w := pbio.NewDelimitedWriter(s) res := pb.DialBackResponse{} - // TODO: Check what happens on sending empty if err := w.WriteMsg(&res); err != nil { log.Debugf("failed to write dialback response: %s", err) s.Reset() @@ -310,7 +309,6 @@ func (ac *client) handleDialBack(s network.Stream) { } func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool { - // TODO: Check this n times if len(connLocalAddr) == 0 || len(dialedAddr) == 0 { return false } @@ -322,32 +320,31 @@ func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) boo if len(localProtos) != len(externalProtos) { return false } - for i := 0; i < len(localProtos); i++ { + for i, lp := range localProtos { + ep := externalProtos[i] if i == 0 { - switch externalProtos[i].Code { + switch ep.Code { case ma.P_DNS, ma.P_DNSADDR: - if localProtos[i].Code == ma.P_IP4 || localProtos[i].Code == ma.P_IP6 { + if lp.Code == ma.P_IP4 || lp.Code == ma.P_IP6 { continue } return false case ma.P_DNS4: - if localProtos[i].Code == ma.P_IP4 { + if lp.Code == ma.P_IP4 { continue } return false case ma.P_DNS6: - if localProtos[i].Code == ma.P_IP6 { + if lp.Code == ma.P_IP6 { continue } return 
false } - if localProtos[i].Code != externalProtos[i].Code { - return false - } - } else { - if localProtos[i].Code != externalProtos[i].Code { + if lp.Code != ep.Code { return false } + } else if lp.Code != ep.Code { + return false } } return true From 0ff385702ef89cd6a7e22f8d81e8d2b96434c03c Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 21 Apr 2025 18:01:29 +0530 Subject: [PATCH 09/15] inject autonat client --- config/config.go | 32 ++-- p2p/host/basic/addrs_reachability_tracker.go | 101 +++++----- .../basic/addrs_reachability_tracker_test.go | 180 +++++++++++++++--- p2p/host/basic/basic_host.go | 21 +- .../12df2ba2ca01d0a2 | 3 + p2p/protocol/autonatv2/autonat.go | 14 +- p2p/protocol/autonatv2/autonat_test.go | 104 +++++++++- p2p/protocol/autonatv2/client.go | 20 +- p2p/protocol/autonatv2/server.go | 6 +- .../testdata/fuzz/FuzzClient/90eb8c62717e4cbe | 3 + 10 files changed, 355 insertions(+), 129 deletions(-) create mode 100644 p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 create mode 100644 p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/90eb8c62717e4cbe diff --git a/config/config.go b/config/config.go index b5d42eed02..83bcb07598 100644 --- a/config/config.go +++ b/config/config.go @@ -33,6 +33,7 @@ import ( routed "github.com/libp2p/go-libp2p/p2p/host/routed" "github.com/libp2p/go-libp2p/p2p/net/swarm" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" @@ -413,15 +414,7 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { return fxopts, nil } -func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) { - var autonatv2Dialer host.Host - if cfg.EnableAutoNATv2 { - ah, err := cfg.makeAutoNATV2Host() - if err != nil { - return nil, err - } - 
autonatv2Dialer = ah - } +func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus, an *autonatv2.AutoNAT) (*bhost.BasicHost, error) { h, err := bhost.NewHost(swrm, &bhost.HostOpts{ EventBus: eventBus, ConnManager: cfg.ConnManager, @@ -437,8 +430,7 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B EnableMetrics: !cfg.DisableMetrics, PrometheusRegisterer: cfg.PrometheusRegisterer, DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery, - EnableAutoNATv2: cfg.EnableAutoNATv2, - AutoNATv2Dialer: autonatv2Dialer, + AutoNATv2: an, }) if err != nil { return nil, err @@ -517,6 +509,24 @@ func (cfg *Config) NewNode() (host.Host, error) { }) return sw, nil }), + fx.Provide(func() (*autonatv2.AutoNAT, error) { + if !cfg.EnableAutoNATv2 { + return nil, nil + } + ah, err := cfg.makeAutoNATV2Host() + if err != nil { + return nil, err + } + var mt autonatv2.MetricsTracer + if !cfg.DisableMetrics { + mt = autonatv2.NewMetricsTracer(cfg.PrometheusRegisterer) + } + autoNATv2, err := autonatv2.New(ah, autonatv2.WithMetricsTracer(mt)) + if err != nil { + return nil, fmt.Errorf("failed to create autonatv2: %w", err) + } + return autoNATv2, nil + }), fx.Provide(cfg.newBasicHost), fx.Provide(func(bh *bhost.BasicHost) identify.IDService { return bh.IDService() diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index 4023108c3a..7e091032a3 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -66,7 +66,7 @@ func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh ch cancel: cancel, cli: client, reachabilityUpdateCh: reachabilityUpdateCh, - addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow, defaultResetInterval, 10), + addrTracker: newProbeManager(cl.Now), newAddrsProbeDelay: newAddrsProbeDelay, maxConcurrency: defaultMaxConcurrency, newAddrs: make(chan []ma.Multiaddr, 1), @@ -150,11 
+150,11 @@ func (r *addrsReachabilityTracker) background() error { <-task.RespCh // ignore backoff from cancelled task task = reachabilityTask{} } + r.updateTrackedAddrs(addrs) newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay) if nextProbeTime.Before(newAddrsNextTime) { nextProbeTime = newAddrsNextTime } - r.updateTrackedAddrs(addrs) case <-r.ctx.Done(): if task.RespCh != nil { task.Cancel() @@ -217,6 +217,7 @@ func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) { } // reachabilityTask is a task to refresh reachability. +// Waiting on the zero value blocks forever. type reachabilityTask struct { Cancel context.CancelFunc RespCh chan bool @@ -244,10 +245,9 @@ var errTooManyConsecutiveFailures = errors.New("too many consecutive failures") // errTooManyConsecutiveFailures in case of persistent failures from autonatv2 module. type errCountingClient struct { autonatv2Client - MaxConsecutiveErrors int - mx sync.Mutex - consecutiveErrors int - loggedPrivateAddrsError bool + MaxConsecutiveErrors int + mx sync.Mutex + consecutiveErrors int } func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { @@ -259,10 +259,8 @@ func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv if c.consecutiveErrors > c.MaxConsecutiveErrors { err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err) } - // This is hacky, but we do want to log this error - if !c.loggedPrivateAddrsError && errors.Is(err, autonatv2.ErrPrivateAddrs) { + if errors.Is(err, autonatv2.ErrPrivateAddrs) { log.Errorf("private IP addr in autonatv2 request: %s", err) - c.loggedPrivateAddrsError = true // log it only once. 
This should never happen } } else { c.consecutiveErrors = 0 @@ -291,7 +289,7 @@ func runProbes(ctx context.Context, concurrency int, addrsTracker *probeManager, jobsCh := make(chan []autonatv2.Request, 1) // close jobs to terminate the workers var wg sync.WaitGroup wg.Add(concurrency) - for i := 0; i < concurrency; i++ { + for range concurrency { go func() { defer wg.Done() for reqs := range jobsCh { @@ -347,13 +345,16 @@ func isErrorPersistent(err error) bool { } const ( - // addrRefusedProbeInterval is the interval to probe addresses that have been refused + // recentProbeInterval is the interval to probe addresses that have been refused // these are generally addresses with newer transports for which we don't have many peers // capable of dialing the transport - addrRefusedProbeInterval = 10 * time.Minute + recentProbeInterval = 10 * time.Minute // maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which - // we wait for `addrRefusedProbeInterval` before probing again + // we wait for `recentProbeInterval` before probing again maxConsecutiveRefusals = 5 + // maxRecentDialsPerAddr is the maximum number of dials on an address before we stop probing for the address. + // This is used to prevent infinite probing of an address whose status is indeterminate for any reason. 
+ maxRecentDialsPerAddr = 10 // confidence is the absolute difference between the number of successes and failures for an address // targetConfidence is the confidence threshold for an address after which we wait for `maxProbeInterval` // before probing again @@ -361,30 +362,26 @@ const ( // minConfidence is the confidence threshold for an address to be considered reachable or unreachable // confidence is the absolute difference between the number of successes and failures for an address minConfidence = 2 - // maxRecentProbeResultWindow is the maximum number of recent probe results to consider for a single address + // maxRecentDialsWindow is the maximum number of recent probe results to consider for a single address // // +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure // and then a success(...S S S S F S). The confidence in the targetConfidence window will be equal to // targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval. - maxRecentProbeResultWindow = targetConfidence + 2 + maxRecentDialsWindow = targetConfidence + 2 // maxProbeInterval is the maximum interval between probes for an address maxProbeInterval = 1 * time.Hour // maxProbeResultTTL is the maximum time to keep probe results for an address - maxProbeResultTTL = maxRecentProbeResultWindow * maxProbeInterval + maxProbeResultTTL = maxRecentDialsWindow * maxProbeInterval ) -// probeManager tracks reachability for a set of addresses. This struct decides the priority order of -// addresses for testing reachability. +// probeManager tracks reachability for a set of addresses by periodically probing reachability with autonatv2. +// A Probe is a list of addresses which can be tested for reachability with autonatv2. +// This struct decides the priority order of addresses for testing reachability, and throttles in case there have +// been too many probes for an address in the `ProbeInterval`. 
// // Use the `runProbes` function to execute the probes with an autonatv2 client. -// -// Probes returned by `GetProbe` should be marked as in progress using `MarkProbeInProgress` -// before being executed. type probeManager struct { - now func() time.Time - recentProbeResultWindow int - ProbeInterval time.Duration - MaxDialsPerAddrsPerInterval int + now func() time.Time mx sync.Mutex inProgressProbes map[string]int // addr -> count @@ -393,14 +390,12 @@ type probeManager struct { addrs []ma.Multiaddr } -func newProbeManager(now func() time.Time, recentProbeResultWindow int, probeInterval time.Duration, maxProbesPerAddrsPerInterval int) *probeManager { +// newProbeManager creates a new probe manager. +func newProbeManager(now func() time.Time) *probeManager { return &probeManager{ - statuses: make(map[string]*addrStatus), - inProgressProbes: make(map[string]int), - now: now, - recentProbeResultWindow: recentProbeResultWindow, - ProbeInterval: probeInterval, - MaxDialsPerAddrsPerInterval: maxProbesPerAddrsPerInterval, + statuses: make(map[string]*addrStatus), + inProgressProbes: make(map[string]int), + now: now, } } @@ -454,7 +449,7 @@ func (m *probeManager) GetProbe() []autonatv2.Request { now := m.now() for i, a := range m.addrs { ab := a.Bytes() - pc := m.requiredProbes(m.statuses[string(ab)], now) + pc := m.requiredProbeCount(m.statuses[string(ab)], now) if pc == 0 { continue } @@ -469,7 +464,7 @@ func (m *probeManager) GetProbe() []autonatv2.Request { for j := 1; j < len(m.addrs); j++ { k := (i + j) % len(m.addrs) ab := m.addrs[k].Bytes() - pc := m.requiredProbes(m.statuses[string(ab)], now) + pc := m.requiredProbeCount(m.statuses[string(ab)], now) if pc == 0 { continue } @@ -527,65 +522,65 @@ func (m *probeManager) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Res return } - expireBefore := now.Add(-maxProbeInterval) + // Consider only primary address as refused. 
This increases the number of + // probes are refused, but refused probes are cheap as no dial is + // made by the server. if res.AllAddrsRefused { if s, ok := m.statuses[primaryAddrKey]; ok { - m.addRefusal(s, now, expireBefore) + m.addRefusal(s, now) } return } - - // Consider only primary address as refused. This increases the number of - // probes are refused, but refused probes are cheap as no dial is - // made by the server. dialAddrKey := string(res.Addr.Bytes()) if dialAddrKey != primaryAddrKey { if s, ok := m.statuses[primaryAddrKey]; ok { - m.addRefusal(s, now, expireBefore) + m.addRefusal(s, now) } } - // record the result for the probed address + // record the result for the dialled address + expireBefore := now.Add(-maxProbeInterval) if s, ok := m.statuses[dialAddrKey]; ok { m.addDial(s, now, res.Reachability, expireBefore) } } -func (*probeManager) addRefusal(s *addrStatus, now time.Time, expireBefore time.Time) { +func (*probeManager) addRefusal(s *addrStatus, now time.Time) { s.lastRefusalTime = now s.consecutiveRefusals++ } -func (m *probeManager) addDial(s *addrStatus, now time.Time, rch network.Reachability, expireBefore time.Time) { +func (*probeManager) addDial(s *addrStatus, now time.Time, rch network.Reachability, expireBefore time.Time) { s.lastRefusalTime = time.Time{} s.consecutiveRefusals = 0 s.dialTimes = append(s.dialTimes, now) - s.outcomes.AddOutcome(now, rch, m.recentProbeResultWindow) + s.outcomes.AddOutcome(now, rch, maxRecentDialsWindow) s.outcomes.RemoveBefore(expireBefore) } -func (m *probeManager) requiredProbes(s *addrStatus, now time.Time) int { +func (m *probeManager) requiredProbeCount(s *addrStatus, now time.Time) int { if s.consecutiveRefusals >= maxConsecutiveRefusals { - if now.Sub(s.lastRefusalTime) < addrRefusedProbeInterval { + if now.Sub(s.lastRefusalTime) < recentProbeInterval { return 0 } - // reset this + // reset every `recentProbeInterval` s.lastRefusalTime = time.Time{} s.consecutiveRefusals = 0 } // Don't 
probe if we have probed too many times recently - if m.recentDialCount(s, now) >= m.MaxDialsPerAddrsPerInterval { + rd := m.recentDialCount(s, now) + if rd >= maxRecentDialsPerAddr { return 0 } - return s.outcomes.RequiredProbes(now) + return s.outcomes.RequiredProbeCount(now) } -func (m *probeManager) recentDialCount(s *addrStatus, now time.Time) int { +func (*probeManager) recentDialCount(s *addrStatus, now time.Time) int { cnt := 0 for _, t := range slices.Backward(s.dialTimes) { - if now.Sub(t) > m.ProbeInterval { + if now.Sub(t) > recentProbeInterval { break } cnt++ @@ -615,7 +610,7 @@ func (o *addrOutcomes) Reachability() network.Reachability { return rch } -func (o *addrOutcomes) RequiredProbes(now time.Time) int { +func (o *addrOutcomes) RequiredProbeCount(now time.Time) int { reachability, successes, failures := o.reachabilityAndCounts() confidence := successes - failures if confidence < 0 { diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index 7bc6576bc7..a6ad9377a9 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -4,6 +4,9 @@ import ( "context" "errors" "fmt" + "math/rand" + "net/netip" + "strings" "sync" "sync/atomic" "testing" @@ -33,13 +36,13 @@ func TestProbeManager(t *testing.T) { } makeNewProbeManager := func(addrs []ma.Multiaddr) *probeManager { - pm := newProbeManager(cl.Now, maxRecentProbeResultWindow, 10*time.Minute, 10) + pm := newProbeManager(cl.Now) pm.UpdateAddrs(addrs) return pm } t.Run("addrs updates", func(t *testing.T) { - pm := newProbeManager(cl.Now, maxRecentProbeResultWindow, 10*time.Minute, 10) + pm := newProbeManager(cl.Now) pm.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) for { reqs := nextProbe(pm) @@ -62,11 +65,11 @@ func TestProbeManager(t *testing.T) { reqs1 := pm.GetProbe() reqs2 := pm.GetProbe() require.Equal(t, reqs1, reqs2) - for i := 0; i < targetConfidence; i++ { + for range 
targetConfidence { reqs := nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) } - for i := 0; i < targetConfidence; i++ { + for range targetConfidence { reqs := nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}}) } @@ -77,32 +80,32 @@ func TestProbeManager(t *testing.T) { t.Run("refusals", func(t *testing.T) { pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) var probes [][]autonatv2.Request - for i := 0; i < targetConfidence; i++ { + for range targetConfidence { reqs := nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) probes = append(probes, reqs) } // first one refused second one successful - for i := 0; i < len(probes); i++ { - pm.CompleteProbe(probes[i], autonatv2.Result{Addr: pub2, Idx: 1, Reachability: network.ReachabilityPublic}, nil) + for _, p := range probes { + pm.CompleteProbe(p, autonatv2.Result{Addr: pub2, Idx: 1, Reachability: network.ReachabilityPublic}, nil) } // the second address is validated! 
probes = nil - for i := 0; i < targetConfidence; i++ { + for range targetConfidence { reqs := nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) probes = append(probes, reqs) } reqs := pm.GetProbe() require.Empty(t, reqs) - for i := 0; i < len(probes); i++ { - pm.CompleteProbe(probes[i], autonatv2.Result{AllAddrsRefused: true}, nil) + for _, p := range probes { + pm.CompleteProbe(p, autonatv2.Result{AllAddrsRefused: true}, nil) } // all requests refused; no more probes for too many refusals reqs = pm.GetProbe() require.Empty(t, reqs) - cl.Add(10*time.Minute + 5*time.Second) + cl.Add(recentProbeInterval) reqs = pm.GetProbe() require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) }) @@ -119,7 +122,7 @@ func TestProbeManager(t *testing.T) { reqs := pm.GetProbe() require.Empty(t, reqs) - cl.Add(1*time.Hour + 5*time.Second) + cl.Add(maxProbeInterval + time.Millisecond) reqs = nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) reqs = nextProbe(pm) @@ -137,17 +140,18 @@ func TestProbeManager(t *testing.T) { } return reachability } - for range 2 * 10 { + // both addresses are indeterminate + for range 2 * maxRecentDialsPerAddr { reqs := nextProbe(pm) pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil) } reqs := pm.GetProbe() require.Empty(t, reqs) - cl.Add(10*time.Minute + 5*time.Second) + cl.Add(recentProbeInterval + time.Millisecond) reqs = pm.GetProbe() require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) - for range 2 * 10 { + for range 2 * maxRecentDialsPerAddr { reqs := nextProbe(pm) pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil) } @@ -218,7 +222,7 @@ func TestAddrsReachabilityTracker(t *testing.T) { reachabilityUpdateCh: make(chan struct{}, 1), 
maxConcurrency: 3, newAddrsProbeDelay: 0 * time.Second, - addrTracker: newProbeManager(cl.Now, maxRecentProbeResultWindow, defaultResetInterval, 10), + addrTracker: newProbeManager(cl.Now), clock: cl, } err := tr.Start() @@ -457,7 +461,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) + addrTracker := newProbeManager(time.Now) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) require.True(t, result) @@ -472,7 +476,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) + addrTracker := newProbeManager(time.Now) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) @@ -490,7 +494,7 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) + addrTracker := newProbeManager(time.Now) addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) var wg sync.WaitGroup wg.Add(1) @@ -523,7 +527,7 @@ func TestRunProbes(t *testing.T) { t.Run("handles refusals", func(t *testing.T) { pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) + addrTracker := newProbeManager(time.Now) addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ @@ -547,7 +551,7 @@ func TestRunProbes(t *testing.T) { }) t.Run("handles completions", func(t *testing.T) { - addrTracker := newProbeManager(time.Now, maxRecentProbeResultWindow, defaultResetInterval, 10) + addrTracker := newProbeManager(time.Now) addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ @@ -623,11 +627,11 @@ func TestDialOutcome(t *testing.T) { } now = now.Add(1 * time.Second) } - require.Equal(t, 
ao.RequiredProbes(now), c.wantRequiredProbes) + require.Equal(t, ao.RequiredProbeCount(now), c.wantRequiredProbes) require.Equal(t, ao.Reachability(), c.wantReachability) if c.wantRequiredProbes == 0 { - now = now.Add(1*time.Hour + 10*time.Microsecond) - require.Equal(t, ao.RequiredProbes(now), 1) + now = now.Add(maxProbeInterval + 10*time.Microsecond) + require.Equal(t, ao.RequiredProbeCount(now), 1) } now = now.Add(1 * time.Second) @@ -639,11 +643,11 @@ func TestDialOutcome(t *testing.T) { func BenchmarkAddrTracker(b *testing.B) { cl := clock.NewMock() - t := newProbeManager(cl.Now, maxRecentProbeResultWindow, 10*time.Minute, 10) + t := newProbeManager(cl.Now) var addrs []ma.Multiaddr - for i := 0; i < 20; i++ { - addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", i))) + for range 20 { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", rand.Intn(1000)))) } t.UpdateAddrs(addrs) b.ReportAllocs() @@ -658,3 +662,125 @@ func BenchmarkAddrTracker(b *testing.B) { t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil) } } + +func FuzzAddrsReachabilityTracker(f *testing.F) { + cl := clock.NewMock() + // The only constraint we force is that result.Idx < len(reqs) + client := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + switch rand.Intn(7) { + case 0: + return autonatv2.Result{AllAddrsRefused: true}, nil + case 1: + return autonatv2.Result{}, errors.New("test error") + case 2: + return autonatv2.Result{}, nil + case 3: + k := rand.Intn(len(reqs)) + r := network.Reachability(rand.Intn(3)) + return autonatv2.Result{Addr: reqs[k].Addr, Idx: k, Reachability: r}, nil + case 4: + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic, AllAddrsRefused: true}, nil + case 5: + return autonatv2.Result{Addr: reqs[0].Addr, Idx: len(reqs) - 1, Reachability: network.ReachabilityPublic, 
AllAddrsRefused: true}, nil + default: + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil + } + }, + } + + randProto := func() ma.Multiaddr { + protoTemplates := []string{ + "/tcp/%d/", + "/udp/%d/", + "/udp/%d/quic-v1/", + "/udp/%d/quic-v1/tcp/%d", + "/udp/%d/quic-v1/webtransport/", + "/udp/%d/webrtc/", + "/udp/%d/webrtc-direct/", + "/unix/hello/", + } + s := protoTemplates[rand.Intn(len(protoTemplates))] + if strings.Count(s, "%d") == 1 { + return ma.StringCast(fmt.Sprintf(s, rand.Intn(1000))) + } + return ma.StringCast(fmt.Sprintf(s, rand.Intn(1000), rand.Intn(1000))) + } + + randIP := func() ma.Multiaddr { + x := rand.Intn(2) + if x == 0 { + i := rand.Int31() + ip := netip.AddrFrom4([4]byte{byte(i), byte(i >> 8), byte(i >> 16), byte(i >> 24)}) + return ma.StringCast(fmt.Sprintf("/ip4/%s/tcp/1", ip)) + } + a, b := rand.Int63(), rand.Int63() + ip := netip.AddrFrom16([16]byte{ + byte(a), byte(a >> 8), byte(a >> 16), byte(a >> 24), + byte(a >> 32), byte(a >> 40), byte(a >> 48), byte(a >> 56), + byte(b), byte(b >> 8), byte(b >> 16), byte(b >> 24), + byte(b >> 32), byte(b >> 40), byte(b >> 48), byte(b >> 56), + }) + return ma.StringCast(fmt.Sprintf("/ip6/%s/tcp/1", ip)) + } + + newAddrs := func() ma.Multiaddr { + switch rand.Intn(5) { + case 0: + return randIP().Encapsulate(randProto()) + case 1: + return randProto() + case 2: + return nil + default: + return randProto().Encapsulate(randIP()) + } + } + + randDNSAddr := func(hostName string) ma.Multiaddr { + var da ma.Multiaddr + switch rand.Intn(4) { + case 0: + da = ma.StringCast(fmt.Sprintf("/dns/%s/", hostName)) + case 1: + da = ma.StringCast(fmt.Sprintf("/dns4/%s/", hostName)) + case 2: + da = ma.StringCast(fmt.Sprintf("/dns6/%s/", hostName)) + default: + da = ma.StringCast(fmt.Sprintf("/dnsaddr/%s/", hostName)) + } + return da.Encapsulate(randProto()) + } + + getAddrs := func(numAddrs int, hostNames []byte) []ma.Multiaddr { + const maxAddrs = 1000 + numAddrs = 
((numAddrs % maxAddrs) + maxAddrs) % maxAddrs + addrs := make([]ma.Multiaddr, numAddrs) + for i := range numAddrs { + addrs[i] = newAddrs() + } + maxDNSAddrs := 10 + for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 { + ed := min(i+2, len(hostNames)) + addrs = append(addrs, randDNSAddr(string(hostNames[i:ed]))) + } + return addrs + } + + f.Fuzz(func(t *testing.T, i int, hostNames []byte) { + tr := newAddrsReachabilityTracker(client, nil, cl) + + require.NoError(t, tr.Start()) + tr.UpdateAddrs(getAddrs(i, hostNames)) + + // fuzz tests need to finish in 10 seconds for some reason + // https://github.com/golang/go/issues/48157 + // https://github.com/golang/go/commit/5d24203c394e6b64c42a9f69b990d94cb6c8aad4#diff-4e3b9481b8794eb058998e2bec389d3db7a23c54e67ac0f7259a3a5d2c79fd04R474-R483 + const maxIters = 20 + for range maxIters { + cl.Add(5 * time.Minute) + time.Sleep(100 * time.Millisecond) + } + require.NoError(t, tr.Close()) + }) +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index a42ad9b2eb..d7b9f40ab2 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -156,8 +156,8 @@ type HostOpts struct { // DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify DisableIdentifyAddressDiscovery bool - EnableAutoNATv2 bool - AutoNATv2Dialer host.Host + + AutoNATv2 *autonatv2.AutoNAT } // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. 
@@ -237,20 +237,11 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { tfl = s.TransportForListening } - if opts.EnableAutoNATv2 { - var mt autonatv2.MetricsTracer - if opts.EnableMetrics { - mt = autonatv2.NewMetricsTracer(opts.PrometheusRegisterer) - } - // keep this on host as it has the server as well as the client - h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer, autonatv2.WithMetricsTracer(mt)) - if err != nil { - return nil, fmt.Errorf("failed to create autonatv2: %w", err) - } + if opts.AutoNATv2 != nil { + h.autonatv2 = opts.AutoNATv2 } - // avoid typed nil errors - var autonatv2Client autonatv2Client + var autonatv2Client autonatv2Client // avoid typed nil errors if h.autonatv2 != nil { autonatv2Client = h.autonatv2 } @@ -327,7 +318,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { func (h *BasicHost) Start() { h.psManager.Start() if h.autonatv2 != nil { - err := h.autonatv2.Start() + err := h.autonatv2.Start(h) if err != nil { log.Errorf("autonat v2 failed to start: %s", err) } diff --git a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 b/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 new file mode 100644 index 0000000000..9a0935d2d6 --- /dev/null +++ b/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 @@ -0,0 +1,3 @@ +go test fuzz v1 +int(0) +[]byte("9") diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index 66c124890d..26870976fb 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -93,7 +93,7 @@ type AutoNAT struct { // New returns a new AutoNAT instance. // host and dialerHost should have the same dialing capabilities. In case the host doesn't support // a transport, dial back requests for address for that transport will be ignored. 
-func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { +func New(dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { s := defaultSettings() for _, o := range opts { if err := o(s); err != nil { @@ -103,11 +103,10 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, ctx, cancel := context.WithCancel(context.Background()) an := &AutoNAT{ - host: host, ctx: ctx, cancel: cancel, - srv: newServer(host, dialerHost, s), - cli: newClient(host), + srv: newServer(dialerHost, s), + cli: newClient(), allowPrivateAddrs: s.allowPrivateAddrs, peers: newPeersMap(), throttlePeer: make(map[peer.ID]time.Time), @@ -148,7 +147,8 @@ func (an *AutoNAT) background(sub event.Subscription) { } } -func (an *AutoNAT) Start() error { +func (an *AutoNAT) Start(h host.Host) error { + an.host = h // Listen on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged // event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers. sub, err := an.host.EventBus().Subscribe([]interface{}{ @@ -159,8 +159,8 @@ func (an *AutoNAT) Start() error { if err != nil { return fmt.Errorf("event subscription failed: %w", err) } - an.cli.Start() - an.srv.Start() + an.cli.Start(h) + an.srv.Start(h) an.wg.Add(1) go an.background(sub) diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 578ddccb1e..53e538cbf7 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -4,6 +4,10 @@ import ( "context" "errors" "fmt" + "math" + "math/rand" + "net/netip" + "strings" "sync/atomic" "testing" "time" @@ -37,11 +41,11 @@ func newAutoNAT(t testing.TB, dialer host.Host, opts ...AutoNATOption) *AutoNAT swarm.WithIPv6BlackHoleSuccessCounter(nil)))) } opts = append([]AutoNATOption{withThrottlePeerDuration(0)}, opts...) - an, err := New(h, dialer, opts...) + an, err := New(dialer, opts...) 
if err != nil { t.Error(err) } - an.Start() + require.NoError(t, an.Start(h)) t.Cleanup(an.Close) return an } @@ -89,7 +93,7 @@ func TestAutoNATPrivateAddr(t *testing.T) { an := newAutoNAT(t, nil) res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) require.Equal(t, res, Result{}) - require.Contains(t, err.Error(), "private address cannot be verified by autonatv2") + require.ErrorIs(t, err, ErrPrivateAddrs) } func TestClientRequest(t *testing.T) { @@ -611,3 +615,97 @@ func TestAreAddrsConsistency(t *testing.T) { }) } } + +func FuzzClient(f *testing.F) { + a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2)) + c := newAutoNAT(f, nil) + idAndWait(f, c, a) + + randProto := func() ma.Multiaddr { + protoTemplates := []string{ + "/tcp/%d/", + "/udp/%d/", + "/udp/%d/quic-v1/", + "/udp/%d/quic-v1/tcp/%d", + "/udp/%d/quic-v1/webtransport/", + "/udp/%d/webrtc/", + "/udp/%d/webrtc-direct/", + "/unix/hello/", + } + s := protoTemplates[rand.Intn(len(protoTemplates))] + if strings.Count(s, "%d") == 1 { + return ma.StringCast(fmt.Sprintf(s, rand.Intn(1000))) + } + return ma.StringCast(fmt.Sprintf(s, rand.Intn(1000), rand.Intn(1000))) + } + + randIP := func() ma.Multiaddr { + x := rand.Intn(2) + if x == 0 { + i := rand.Int31() + ip := netip.AddrFrom4([4]byte{byte(i), byte(i >> 8), byte(i >> 16), byte(i >> 24)}) + return ma.StringCast(fmt.Sprintf("/ip4/%s/tcp/1", ip)) + } + a, b := rand.Int63(), rand.Int63() + ip := netip.AddrFrom16([16]byte{ + byte(a), byte(a >> 8), byte(a >> 16), byte(a >> 24), + byte(a >> 32), byte(a >> 40), byte(a >> 48), byte(a >> 56), + byte(b), byte(b >> 8), byte(b >> 16), byte(b >> 24), + byte(b >> 32), byte(b >> 40), byte(b >> 48), byte(b >> 56), + }) + return ma.StringCast(fmt.Sprintf("/ip6/%s/tcp/1", ip)) + } + + newAddrs := func() ma.Multiaddr { + switch rand.Intn(5) { + case 0: + return randIP().Encapsulate(randProto()) + case 
1: + return randProto() + case 2: + return nil + default: + return randProto().Encapsulate(randIP()) + } + } + + randDNSAddr := func(hostName string) ma.Multiaddr { + if len(hostName) == 0 { + panic("wtf") + } + var da ma.Multiaddr + switch rand.Intn(4) { + case 0: + da = ma.StringCast(fmt.Sprintf("/dns/%s/", hostName)) + case 1: + da = ma.StringCast(fmt.Sprintf("/dns4/%s/", hostName)) + case 2: + da = ma.StringCast(fmt.Sprintf("/dns6/%s/", hostName)) + default: + da = ma.StringCast(fmt.Sprintf("/dnsaddr/%s/", hostName)) + } + return da.Encapsulate(randProto()) + } + + // reduce the streamTimeout before running this. TODO: fix this + f.Fuzz(func(t *testing.T, i int, hostNames []byte) { + const maxAddrs = 100 + numAddrs := ((i % maxAddrs) + maxAddrs) % maxAddrs + addrs := make([]ma.Multiaddr, numAddrs) + for i := range numAddrs { + addrs[i] = newAddrs() + } + maxDNSAddrs := 10 + hostNamesStr := strings.ReplaceAll(string(hostNames), "\\", "") + hostNamesStr = strings.ReplaceAll(hostNamesStr, "/", "") + for i := 0; i < len(hostNamesStr) && i < 2*maxDNSAddrs; i += 2 { + ed := min(i+2, len(hostNamesStr)) + addrs = append(addrs, randDNSAddr(hostNamesStr[i:ed])) + } + reqs := make([]Request, len(addrs)) + for i, addr := range addrs { + reqs[i] = Request{Addr: addr, SendDialData: true} + } + c.GetReachability(context.Background(), reqs) + }) +} diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index fea217b1e3..15b454e35f 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -35,20 +35,20 @@ type normalizeMultiaddrer interface { NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr } -func newClient(h host.Host) *client { - normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a } - if hn, ok := h.(normalizeMultiaddrer); ok { - normalizeMultiaddr = hn.NormalizeMultiaddr - } +func newClient() *client { return &client{ - host: h, - dialData: make([]byte, 4000), - normalizeMultiaddr: normalizeMultiaddr, - 
dialBackQueues: make(map[uint64]chan ma.Multiaddr), + dialData: make([]byte, 4000), + dialBackQueues: make(map[uint64]chan ma.Multiaddr), } } -func (ac *client) Start() { +func (ac *client) Start(h host.Host) { + normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a } + if hn, ok := h.(normalizeMultiaddrer); ok { + normalizeMultiaddr = hn.NormalizeMultiaddr + } + ac.host = h + ac.normalizeMultiaddr = normalizeMultiaddr ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack) } diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index 92bf42cef4..d6e6e73431 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -59,10 +59,9 @@ type server struct { allowPrivateAddrs bool } -func newServer(host, dialer host.Host, s *autoNATSettings) *server { +func newServer(dialer host.Host, s *autoNATSettings) *server { return &server{ dialerHost: dialer, - host: host, dialDataRequestPolicy: s.dataRequestPolicy, amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait, allowPrivateAddrs: s.allowPrivateAddrs, @@ -79,7 +78,8 @@ func newServer(host, dialer host.Host, s *autoNATSettings) *server { } // Enable attaches the stream handler to the host. 
-func (as *server) Start() { +func (as *server) Start(h host.Host) { + as.host = h as.host.SetStreamHandler(DialProtocol, as.handleDialRequest) } diff --git a/p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/90eb8c62717e4cbe b/p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/90eb8c62717e4cbe new file mode 100644 index 0000000000..5dfcc07bcd --- /dev/null +++ b/p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/90eb8c62717e4cbe @@ -0,0 +1,3 @@ +go test fuzz v1 +int(26) +[]byte("/0") From d4f1b31f430be6ca28a5a39574ef517916436264 Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 22 Apr 2025 23:18:39 +0530 Subject: [PATCH 10/15] linting --- p2p/host/basic/addrs_reachability_tracker_test.go | 1 - p2p/protocol/autonatv2/server_test.go | 1 - 2 files changed, 2 deletions(-) diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index a6ad9377a9..51ee4bc6ae 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -157,7 +157,6 @@ func TestProbeManager(t *testing.T) { } reqs = pm.GetProbe() require.Empty(t, reqs) - }) t.Run("reachabilityUpdate", func(t *testing.T) { diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 73c89a55fc..e2814e2dc7 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -136,7 +136,6 @@ func TestServerInvalidAddrsRejected(t *testing.T) { require.ErrorIs(t, err, network.ErrReset) require.Equal(t, Result{}, res) }) - } func TestServerDataRequest(t *testing.T) { From b00ff48b294144d199dee3996a88f86338b9a065 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 28 Apr 2025 22:01:22 +0530 Subject: [PATCH 11/15] cleanup fuzzing --- p2p/host/basic/addrs_manager.go | 17 ++-- p2p/host/basic/addrs_manager_test.go | 9 +-- .../basic/addrs_reachability_tracker_test.go | 81 ++++++++++++------- .../12df2ba2ca01d0a2 | 3 - p2p/protocol/autonatv2/autonat_test.go | 47 
++++++----- 5 files changed, 86 insertions(+), 71 deletions(-) delete mode 100644 p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 diff --git a/p2p/host/basic/addrs_manager.go b/p2p/host/basic/addrs_manager.go index 4114a43558..e6556d14d4 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -245,6 +245,8 @@ func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub even } func (a *addrsManager) updateAddrs() hostAddrs { + // Must lock while doing both recompute and update as this method is called from + // multiple goroutines. a.addrsMx.Lock() defer a.addrsMx.Unlock() @@ -370,10 +372,12 @@ func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAd // Only include host addresses as the reachability manager may have // a stale view of host's addresses. reachableAddrs = slices.DeleteFunc(reachableAddrs, func(a ma.Multiaddr) bool { - return !contains(localAddrs, a) + _, ok := slices.BinarySearchFunc(localAddrs, a, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + return !ok }) unreachableAddrs = slices.DeleteFunc(unreachableAddrs, func(a ma.Multiaddr) bool { - return !contains(localAddrs, a) + _, ok := slices.BinarySearchFunc(localAddrs, a, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + return !ok }) return reachableAddrs, unreachableAddrs } @@ -555,15 +559,6 @@ func areAddrsDifferent(prev, current []ma.Multiaddr) bool { return false } -func contains(addrs []ma.Multiaddr, addr ma.Multiaddr) bool { - for _, a := range addrs { - if a.Equal(addr) { - return true - } - } - return false -} - const interfaceAddrsCacheTTL = time.Minute type interfaceAddrsCache struct { diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index 69d264f8f6..e76957e691 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -436,6 +436,7 @@ func TestAddrsManager(t *testing.T) { func 
TestAddrsManagerReachabilityEvent(t *testing.T) { publicQUIC, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1") + publicQUIC2, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1235/quic-v1") publicTCP, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") bus := eventbus.NewBus() @@ -447,23 +448,21 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) { am := newAddrsManagerTestCase(t, addrsManagerArgs{ Bus: bus, // currently they aren't being passed to the reachability tracker - ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicTCP} }, + ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicQUIC2, publicTCP} }, AutoNATClient: mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { if reqs[0].Addr.Equal(publicQUIC) { return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil - } else if reqs[0].Addr.Equal(publicTCP) { + } else if reqs[0].Addr.Equal(publicTCP) || reqs[0].Addr.Equal(publicQUIC2) { return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPrivate}, nil } - t.Errorf("received invalid request for addr: %+v", reqs[0]) return autonatv2.Result{}, errors.New("invalid") }, }, }) reachableAddrs := []ma.Multiaddr{publicQUIC} - unreachableAddrs := []ma.Multiaddr{publicTCP} - + unreachableAddrs := []ma.Multiaddr{publicTCP, publicQUIC2} select { case e := <-sub.Out(): evt := e.(event.EvtHostReachableAddrsChanged) diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index 51ee4bc6ae..5dca1ec810 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -2,9 +2,11 @@ package basichost import ( "context" + "encoding/binary" "errors" "fmt" "math/rand" + "net" "net/netip" "strings" "sync" @@ -663,31 +665,50 @@ func BenchmarkAddrTracker(b *testing.B) { } func FuzzAddrsReachabilityTracker(f 
*testing.F) { - cl := clock.NewMock() + type autonatv2Response struct { + Result autonatv2.Result + Err error + } // The only constraint we force is that result.Idx < len(reqs) client := mockAutoNATClient{ F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { - switch rand.Intn(7) { - case 0: - return autonatv2.Result{AllAddrsRefused: true}, nil - case 1: - return autonatv2.Result{}, errors.New("test error") - case 2: - return autonatv2.Result{}, nil - case 3: - k := rand.Intn(len(reqs)) - r := network.Reachability(rand.Intn(3)) - return autonatv2.Result{Addr: reqs[k].Addr, Idx: k, Reachability: r}, nil - case 4: - return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic, AllAddrsRefused: true}, nil - case 5: - return autonatv2.Result{Addr: reqs[0].Addr, Idx: len(reqs) - 1, Reachability: network.ReachabilityPublic, AllAddrsRefused: true}, nil - default: - return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil + if rand.Intn(3) == 0 { + // some address confirmed + x := rand.Intn(3) + rch := network.Reachability(x) + n := rand.Intn(len(reqs)) + return autonatv2.Result{ + Addr: reqs[n].Addr, + Idx: n, + Reachability: rch, + }, nil + } + outcomes := []autonatv2Response{ + {Result: autonatv2.Result{AllAddrsRefused: true}}, + {Err: errors.New("test error")}, + {Err: autonatv2.ErrPrivateAddrs}, + {Err: autonatv2.ErrNoPeers}, + {Result: autonatv2.Result{}, Err: nil}, + {Result: autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}}, + {Result: autonatv2.Result{ + Addr: reqs[0].Addr, + Idx: 0, + Reachability: network.ReachabilityPublic, + AllAddrsRefused: true, + }}, + {Result: autonatv2.Result{ + Addr: reqs[0].Addr, + Idx: len(reqs) - 1, // invalid idx + Reachability: network.ReachabilityPublic, + AllAddrsRefused: false, + }}, } + outcome := outcomes[rand.Intn(len(outcomes))] + return outcome.Result, outcome.Err }, } + // TODO: Move 
this to go-multiaddrs randProto := func() ma.Multiaddr { protoTemplates := []string{ "/tcp/%d/", @@ -714,6 +735,10 @@ func FuzzAddrsReachabilityTracker(f *testing.F) { return ma.StringCast(fmt.Sprintf("/ip4/%s/tcp/1", ip)) } a, b := rand.Int63(), rand.Int63() + if rand.Intn(2) == 0 { + pubIP := net.ParseIP("2005::") // Public IP address + a = int64(binary.LittleEndian.Uint64(pubIP[0:8])) + } ip := netip.AddrFrom16([16]byte{ byte(a), byte(a >> 8), byte(a >> 16), byte(a >> 24), byte(a >> 32), byte(a >> 40), byte(a >> 48), byte(a >> 56), @@ -737,22 +762,18 @@ func FuzzAddrsReachabilityTracker(f *testing.F) { } randDNSAddr := func(hostName string) ma.Multiaddr { - var da ma.Multiaddr - switch rand.Intn(4) { - case 0: - da = ma.StringCast(fmt.Sprintf("/dns/%s/", hostName)) - case 1: - da = ma.StringCast(fmt.Sprintf("/dns4/%s/", hostName)) - case 2: - da = ma.StringCast(fmt.Sprintf("/dns6/%s/", hostName)) - default: - da = ma.StringCast(fmt.Sprintf("/dnsaddr/%s/", hostName)) + dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} + if hostName == "" { + hostName = "localhost" } + hostName = strings.ReplaceAll(hostName, "\\", "") + hostName = strings.ReplaceAll(hostName, "/", "") + da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[rand.Intn(len(dnsProtos))], hostName)) return da.Encapsulate(randProto()) } + const maxAddrs = 1000 getAddrs := func(numAddrs int, hostNames []byte) []ma.Multiaddr { - const maxAddrs = 1000 numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs addrs := make([]ma.Multiaddr, numAddrs) for i := range numAddrs { @@ -766,9 +787,9 @@ func FuzzAddrsReachabilityTracker(f *testing.F) { return addrs } + cl := clock.NewMock() f.Fuzz(func(t *testing.T, i int, hostNames []byte) { tr := newAddrsReachabilityTracker(client, nil, cl) - require.NoError(t, tr.Start()) tr.UpdateAddrs(getAddrs(i, hostNames)) diff --git a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 
b/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 deleted file mode 100644 index 9a0935d2d6..0000000000 --- a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/12df2ba2ca01d0a2 +++ /dev/null @@ -1,3 +0,0 @@ -go test fuzz v1 -int(0) -[]byte("9") diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 53e538cbf7..73111adbe5 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -2,10 +2,12 @@ package autonatv2 import ( "context" + "encoding/binary" "errors" "fmt" "math" "math/rand" + "net" "net/netip" "strings" "sync/atomic" @@ -621,6 +623,7 @@ func FuzzClient(f *testing.F) { c := newAutoNAT(f, nil) idAndWait(f, c, a) + // TODO: Move this to go-multiaddrs randProto := func() ma.Multiaddr { protoTemplates := []string{ "/tcp/%d/", @@ -647,6 +650,10 @@ func FuzzClient(f *testing.F) { return ma.StringCast(fmt.Sprintf("/ip4/%s/tcp/1", ip)) } a, b := rand.Int63(), rand.Int63() + if rand.Intn(2) == 0 { + pubIP := net.ParseIP("2005::") // Public IP address + a = int64(binary.LittleEndian.Uint64(pubIP[0:8])) + } ip := netip.AddrFrom16([16]byte{ byte(a), byte(a >> 8), byte(a >> 16), byte(a >> 24), byte(a >> 32), byte(a >> 40), byte(a >> 48), byte(a >> 56), @@ -670,38 +677,34 @@ func FuzzClient(f *testing.F) { } randDNSAddr := func(hostName string) ma.Multiaddr { - if len(hostName) == 0 { - panic("wtf") - } - var da ma.Multiaddr - switch rand.Intn(4) { - case 0: - da = ma.StringCast(fmt.Sprintf("/dns/%s/", hostName)) - case 1: - da = ma.StringCast(fmt.Sprintf("/dns4/%s/", hostName)) - case 2: - da = ma.StringCast(fmt.Sprintf("/dns6/%s/", hostName)) - default: - da = ma.StringCast(fmt.Sprintf("/dnsaddr/%s/", hostName)) + dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} + if hostName == "" { + hostName = "localhost" } + hostName = strings.ReplaceAll(hostName, "\\", "") + hostName = strings.ReplaceAll(hostName, "/", "") + da := 
ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[rand.Intn(len(dnsProtos))], hostName)) return da.Encapsulate(randProto()) } - // reduce the streamTimeout before running this. TODO: fix this - f.Fuzz(func(t *testing.T, i int, hostNames []byte) { - const maxAddrs = 100 - numAddrs := ((i % maxAddrs) + maxAddrs) % maxAddrs + const maxAddrs = 1000 + getAddrs := func(numAddrs int, hostNames []byte) []ma.Multiaddr { + numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs addrs := make([]ma.Multiaddr, numAddrs) for i := range numAddrs { addrs[i] = newAddrs() } maxDNSAddrs := 10 - hostNamesStr := strings.ReplaceAll(string(hostNames), "\\", "") - hostNamesStr = strings.ReplaceAll(hostNamesStr, "/", "") - for i := 0; i < len(hostNamesStr) && i < 2*maxDNSAddrs; i += 2 { - ed := min(i+2, len(hostNamesStr)) - addrs = append(addrs, randDNSAddr(hostNamesStr[i:ed])) + for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 { + ed := min(i+2, len(hostNames)) + addrs = append(addrs, randDNSAddr(string(hostNames[i:ed]))) } + return addrs + } + + // reduce the streamTimeout before running this. 
TODO: fix this + f.Fuzz(func(t *testing.T, numAddrs int, hostNames []byte) { + addrs := getAddrs(numAddrs, hostNames) reqs := make([]Request, len(addrs)) for i, addr := range addrs { reqs[i] = Request{Addr: addr, SendDialData: true} From 0b46ec6f61fde4d51d2272ec5c7791ba91d8d609 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 28 Apr 2025 23:13:34 +0530 Subject: [PATCH 12/15] fix sorted diff --- p2p/host/basic/addrs_manager.go | 88 ++++++++++++------- p2p/host/basic/addrs_manager_test.go | 40 +++++++++ .../basic/addrs_reachability_tracker_test.go | 33 ++++++- p2p/protocol/autonatv2/autonat_test.go | 6 +- .../testdata/fuzz/FuzzClient/7823777b04c6fb64 | 3 + 5 files changed, 130 insertions(+), 40 deletions(-) create mode 100644 p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/7823777b04c6fb64 diff --git a/p2p/host/basic/addrs_manager.go b/p2p/host/basic/addrs_manager.go index e6556d14d4..bc0c75425f 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -53,14 +53,14 @@ type addrsManager struct { triggerAddrsUpdateChan chan struct{} // triggerReachabilityUpdate is notified when reachable addrs are updated. triggerReachabilityUpdate chan struct{} - // triggerHostReachabilityUpdate is notified when host's reachability from autonat v1 changes. - triggerHostReachabilityUpdate chan struct{} hostReachability atomic.Pointer[network.Reachability] addrsMx sync.RWMutex // protects fields below currentAddrs hostAddrs - relayAddrs []ma.Multiaddr + // relayAddrs are the host's relay addresses. Kept separate from hostAddrs as we + // update them differently from hostAddrs. 
+ relayAddrs []ma.Multiaddr wg sync.WaitGroup ctx context.Context @@ -79,19 +79,18 @@ func newAddrsManager( ) (*addrsManager, error) { ctx, cancel := context.WithCancel(context.Background()) as := &addrsManager{ - bus: bus, - listenAddrs: listenAddrs, - transportForListening: transportForListening, - observedAddrsManager: observedAddrsManager, - natManager: natmgr, - addrsFactory: addrsFactory, - triggerAddrsUpdateChan: make(chan struct{}, 1), - triggerHostReachabilityUpdate: make(chan struct{}, 1), - triggerReachabilityUpdate: make(chan struct{}, 1), - addrsUpdatedChan: addrsUpdatedChan, - interfaceAddrs: &interfaceAddrsCache{}, - ctx: ctx, - ctxCancel: cancel, + bus: bus, + listenAddrs: listenAddrs, + transportForListening: transportForListening, + observedAddrsManager: observedAddrsManager, + natManager: natmgr, + addrsFactory: addrsFactory, + triggerAddrsUpdateChan: make(chan struct{}, 1), + triggerReachabilityUpdate: make(chan struct{}, 1), + addrsUpdatedChan: addrsUpdatedChan, + interfaceAddrs: &interfaceAddrsCache{}, + ctx: ctx, + ctxCancel: cancel, } unknownReachability := network.ReachabilityUnknown as.hostReachability.Store(&unknownReachability) @@ -229,11 +228,11 @@ func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub even select { case <-ticker.C: case <-a.triggerAddrsUpdateChan: + case <-a.triggerReachabilityUpdate: case e := <-autoRelayAddrsSub.Out(): if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { a.updateRelayAddrs(evt.RelayAddrs) } - case <-a.triggerReachabilityUpdate: case e := <-autonatReachabilitySub.Out(): if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { a.hostReachability.Store(&evt.Reachability) @@ -272,6 +271,12 @@ func (a *addrsManager) updateAddrs() hostAddrs { } } +func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) { + a.addrsMx.Lock() + defer a.addrsMx.Unlock() + a.relayAddrs = append(a.relayAddrs[:0], addrs...) 
+} + func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) { if areAddrsDifferent(previous.localAddrs, current.localAddrs) { log.Debugf("host local addresses updated: %s", current.localAddrs) @@ -325,6 +330,8 @@ func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multi return addrs } +// HolePunchAddrs returns all the host's direct public addresses, reachable or unreachable, +// suitable for hole punching. func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr { addrs := a.DirectAddrs() addrs = slices.Clone(a.addrsFactory(addrs)) @@ -351,6 +358,7 @@ func (a *addrsManager) ReachableAddrs() []ma.Multiaddr { return slices.Clone(a.currentAddrs.reachableAddrs) } +// RelayAddrs returns all the relay addresses of the host. func (a *addrsManager) RelayAddrs() []ma.Multiaddr { a.addrsMx.RLock() defer a.addrsMx.RUnlock() @@ -361,25 +369,37 @@ func (a *addrsManager) getRelayAddrsUnlocked() []ma.Multiaddr { return slices.Clone(a.relayAddrs) } -func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) { - a.addrsMx.Lock() - defer a.addrsMx.Unlock() - a.relayAddrs = append(a.relayAddrs[:0], addrs...) -} - func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { reachableAddrs, unreachableAddrs = a.addrsReachabilityTracker.ConfirmedAddrs() - // Only include host addresses as the reachability manager may have - // a stale view of host's addresses. 
- reachableAddrs = slices.DeleteFunc(reachableAddrs, func(a ma.Multiaddr) bool { - _, ok := slices.BinarySearchFunc(localAddrs, a, func(a, b ma.Multiaddr) int { return a.Compare(b) }) - return !ok - }) - unreachableAddrs = slices.DeleteFunc(unreachableAddrs, func(a ma.Multiaddr) bool { - _, ok := slices.BinarySearchFunc(localAddrs, a, func(a, b ma.Multiaddr) int { return a.Compare(b) }) - return !ok - }) - return reachableAddrs, unreachableAddrs + return removeIfNotInSource(reachableAddrs, localAddrs), removeIfNotInSource(unreachableAddrs, localAddrs) +} + +// removeIfNotInSource removes items from addrs that are not present in source. +// Modifies the addrs slice in place +// addrs and source must be sorted using multiaddr.Compare. +func removeIfNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr { + j := 0 + // mark entries not in source as nil + for i, a := range addrs { + // move right till a is greater + for j < len(source) && a.Compare(source[j]) > 0 { + j++ + } + // a is not in source if we've reached the end, or a is lesser + if j == len(source) || a.Compare(source[j]) < 0 { + addrs[i] = nil + } + // a is in source, nothing to do + } + // j is the current element, i is the lowest index nil element + i := 0 + for j := 0; j < len(addrs); j++ { + if addrs[j] != nil { + addrs[i], addrs[j] = addrs[j], addrs[i] + i++ + } + } + return addrs[:i] } var p2pCircuitAddr = ma.StringCast("/p2p-circuit") diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index e76957e691..41f1e44b1e 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "testing" "time" @@ -474,6 +475,33 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) { } } +func TestRemoveIfNotInSource(t *testing.T) { + var addrs []ma.Multiaddr + for i := 0; i < 10; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", i))) + } + 
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + cases := []struct { + addrs []ma.Multiaddr + source []ma.Multiaddr + expected []ma.Multiaddr + }{ + {}, + {addrs: slices.Clone(addrs[:5]), source: nil, expected: nil}, + {addrs: nil, source: addrs, expected: nil}, + {addrs: []ma.Multiaddr{addrs[0]}, source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}}, + {addrs: slices.Clone(addrs), source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}}, + {addrs: slices.Clone(addrs), source: slices.Clone(addrs[5:]), expected: slices.Clone(addrs[5:])}, + {addrs: slices.Clone(addrs[:5]), source: []ma.Multiaddr{addrs[0], addrs[2], addrs[8]}, expected: []ma.Multiaddr{addrs[0], addrs[2]}}, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + addrs := removeIfNotInSource(tc.addrs, tc.source) + require.ElementsMatch(t, tc.expected, addrs, "%s\n%s", tc.expected, tc.addrs) + }) + } +} + func BenchmarkAreAddrsDifferent(b *testing.B) { var addrs [10]ma.Multiaddr for i := 0; i < len(addrs); i++ { @@ -487,3 +515,15 @@ func BenchmarkAreAddrsDifferent(b *testing.B) { } }) } + +func BenchmarkRemoveIfNotInSource(b *testing.B) { + var addrs [10]ma.Multiaddr + for i := 0; i < len(addrs); i++ { + addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i)) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + removeIfNotInSource(slices.Clone(addrs[:5]), addrs[:]) + } +} diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index 5dca1ec810..6ea0f36fd7 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -8,6 +8,7 @@ import ( "math/rand" "net" "net/netip" + "slices" "strings" "sync" "sync/atomic" @@ -271,6 +272,32 @@ func TestAddrsReachabilityTracker(t *testing.T) { require.Empty(t, unreachable) }) + t.Run("confirmed addrs ordering", func(t *testing.T) { + 
mockClient := mockAutoNATClient{ + F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil + }, + } + tr := newTracker(mockClient, nil) + var addrs []ma.Multiaddr + for i := 0; i < 10; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", i))) + } + slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return -a.Compare(b) }) // sort in reverse order + tr.UpdateAddrs(addrs) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(2 * time.Second): + t.Fatal("expected reachability update") + } + reachable, unreachable := tr.ConfirmedAddrs() + require.Empty(t, unreachable) + + orderedAddrs := slices.Clone(addrs) + slices.Reverse(orderedAddrs) + require.Equal(t, reachable, orderedAddrs, "%s %s", reachable, addrs) + }) + t.Run("backoff", func(t *testing.T) { notify := make(chan struct{}, 1) drainNotify := func() bool { @@ -762,12 +789,12 @@ func FuzzAddrsReachabilityTracker(f *testing.F) { } randDNSAddr := func(hostName string) ma.Multiaddr { - dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} + hostName = strings.ReplaceAll(hostName, "\\", "") + hostName = strings.ReplaceAll(hostName, "/", "") if hostName == "" { hostName = "localhost" } - hostName = strings.ReplaceAll(hostName, "\\", "") - hostName = strings.ReplaceAll(hostName, "/", "") + dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[rand.Intn(len(dnsProtos))], hostName)) return da.Encapsulate(randProto()) } diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 73111adbe5..991c03b55d 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -677,12 +677,12 @@ func FuzzClient(f *testing.F) { } randDNSAddr := func(hostName string) ma.Multiaddr { - dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} + 
hostName = strings.ReplaceAll(hostName, "\\", "") + hostName = strings.ReplaceAll(hostName, "/", "") if hostName == "" { hostName = "localhost" } - hostName = strings.ReplaceAll(hostName, "\\", "") - hostName = strings.ReplaceAll(hostName, "/", "") + dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[rand.Intn(len(dnsProtos))], hostName)) return da.Encapsulate(randProto()) } diff --git a/p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/7823777b04c6fb64 b/p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/7823777b04c6fb64 new file mode 100644 index 0000000000..e72a0836b7 --- /dev/null +++ b/p2p/protocol/autonatv2/testdata/fuzz/FuzzClient/7823777b04c6fb64 @@ -0,0 +1,3 @@ +go test fuzz v1 +int(-22) +[]byte("/") From d048ba76c06492329c561ac0338519e0273fea9a Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 14 May 2025 15:34:57 +0530 Subject: [PATCH 13/15] review comments --- .golangci.yml | 2 + p2p/host/basic/addrs_manager.go | 127 +++--- p2p/host/basic/addrs_manager_test.go | 22 +- p2p/host/basic/addrs_reachability_tracker.go | 385 ++++++++---------- .../basic/addrs_reachability_tracker_test.go | 70 ++-- .../4f31b7942ec62406 | 6 - .../53e52cff547ff885 | 6 - .../79485637e486f9db | 6 - p2p/protocol/autonatv2/autonat.go | 12 +- p2p/protocol/autonatv2/autonat_test.go | 47 ++- p2p/protocol/autonatv2/client.go | 6 +- 11 files changed, 349 insertions(+), 340 deletions(-) delete mode 100644 p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/4f31b7942ec62406 delete mode 100644 p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/53e52cff547ff885 delete mode 100644 p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/79485637e486f9db diff --git a/.golangci.yml b/.golangci.yml index 645266d523..eac7bfe3fd 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -8,6 +8,8 @@ linters: - revive - unused - prealloc + disable: + - errcheck settings: revive: diff --git a/p2p/host/basic/addrs_manager.go 
b/p2p/host/basic/addrs_manager.go index bc0c75425f..6c870f43cc 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -34,6 +34,7 @@ type hostAddrs struct { localAddrs []ma.Multiaddr reachableAddrs []ma.Multiaddr unreachableAddrs []ma.Multiaddr + relayAddrs []ma.Multiaddr } type addrsManager struct { @@ -56,11 +57,8 @@ type addrsManager struct { hostReachability atomic.Pointer[network.Reachability] - addrsMx sync.RWMutex // protects fields below + addrsMx sync.RWMutex currentAddrs hostAddrs - // relayAddrs are the host's relay addresses. Kept separate from hostAddrs as we - // update them differently from hostAddrs. - relayAddrs []ma.Multiaddr wg sync.WaitGroup ctx context.Context @@ -140,7 +138,7 @@ func (a *addrsManager) NetNotifee() network.Notifiee { } func (a *addrsManager) triggerAddrsUpdate() { - a.updateAddrs() + a.updateAddrs(false, nil) select { case a.triggerAddrsUpdateChan <- struct{}{}: default: @@ -177,11 +175,12 @@ func (a *addrsManager) startBackgroundWorker() error { return errors.Join(err, err1, err2) } + var relayAddrs []ma.Multiaddr // update relay addrs in case we're private select { case e := <-autoRelayAddrsSub.Out(): if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { - a.updateRelayAddrs(evt.RelayAddrs) + relayAddrs = slices.Clone(evt.RelayAddrs) } default: } @@ -195,24 +194,23 @@ func (a *addrsManager) startBackgroundWorker() error { } // update addresses before starting the worker loop. This ensures that any address updates // before calling addrsManager.Start are correctly reported after Start returns. 
- a.updateAddrs() + a.updateAddrs(true, relayAddrs) + a.wg.Add(1) - go func() { - defer a.wg.Done() - a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter) - }() + go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, relayAddrs) return nil } -func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription, emitter event.Emitter) { +func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription, + emitter event.Emitter, relayAddrs []ma.Multiaddr, +) { + defer a.wg.Done() defer func() { err := autoRelayAddrsSub.Close() if err != nil { log.Warnf("error closing auto relay addrs sub: %s", err) } - }() - defer func() { - err := autonatReachabilitySub.Close() + err = autonatReachabilitySub.Close() if err != nil { log.Warnf("error closing autonat reachability sub: %s", err) } @@ -222,7 +220,7 @@ func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub even defer ticker.Stop() var previousAddrs hostAddrs for { - currAddrs := a.updateAddrs() + currAddrs := a.updateAddrs(true, relayAddrs) a.notifyAddrsChanged(emitter, previousAddrs, currAddrs) previousAddrs = currAddrs select { @@ -231,7 +229,7 @@ func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub even case <-a.triggerReachabilityUpdate: case e := <-autoRelayAddrsSub.Out(): if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { - a.updateRelayAddrs(evt.RelayAddrs) + relayAddrs = slices.Clone(evt.RelayAddrs) } case e := <-autonatReachabilitySub.Out(): if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { @@ -243,7 +241,9 @@ func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub even } } -func (a *addrsManager) updateAddrs() hostAddrs { +// updateAddrs updates the addresses of the host and returns the new updated +// addrs +func (a *addrsManager) updateAddrs(updateRelayAddrs bool, relayAddrs []ma.Multiaddr) hostAddrs { // Must lock while doing both recompute and update as 
this method is called from // multiple goroutines. a.addrsMx.Lock() @@ -254,13 +254,20 @@ func (a *addrsManager) updateAddrs() hostAddrs { if a.addrsReachabilityTracker != nil { currReachableAddrs, currUnreachableAddrs = a.getConfirmedAddrs(localAddrs) } - currAddrs := a.getAddrs(slices.Clone(localAddrs), a.getRelayAddrsUnlocked()) + if !updateRelayAddrs { + relayAddrs = a.currentAddrs.relayAddrs + } else { + // Copy the callers slice + relayAddrs = slices.Clone(relayAddrs) + } + currAddrs := a.getAddrs(slices.Clone(localAddrs), relayAddrs) a.currentAddrs = hostAddrs{ addrs: append(a.currentAddrs.addrs[:0], currAddrs...), localAddrs: append(a.currentAddrs.localAddrs[:0], localAddrs...), reachableAddrs: append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...), unreachableAddrs: append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...), + relayAddrs: append(a.currentAddrs.relayAddrs[:0], relayAddrs...), } return hostAddrs{ @@ -268,15 +275,10 @@ func (a *addrsManager) updateAddrs() hostAddrs { addrs: currAddrs, reachableAddrs: currReachableAddrs, unreachableAddrs: currUnreachableAddrs, + relayAddrs: relayAddrs, } } -func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) { - a.addrsMx.Lock() - defer a.addrsMx.Unlock() - a.relayAddrs = append(a.relayAddrs[:0], addrs...) -} - func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) { if areAddrsDifferent(previous.localAddrs, current.localAddrs) { log.Debugf("host local addresses updated: %s", current.localAddrs) @@ -308,7 +310,11 @@ func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, curre // If autorelay is enabled and node reachability is private, it returns // the node's relay addresses and private network addresses. 
func (a *addrsManager) Addrs() []ma.Multiaddr { - return a.getAddrs(a.DirectAddrs(), a.RelayAddrs()) + a.addrsMx.RLock() + directAddrs := slices.Clone(a.currentAddrs.localAddrs) + relayAddrs := slices.Clone(a.currentAddrs.relayAddrs) + a.addrsMx.RUnlock() + return a.getAddrs(directAddrs, relayAddrs) } // getAddrs returns the node's dialable addresses. Mutates localAddrs @@ -358,48 +364,9 @@ func (a *addrsManager) ReachableAddrs() []ma.Multiaddr { return slices.Clone(a.currentAddrs.reachableAddrs) } -// RelayAddrs returns all the relay addresses of the host. -func (a *addrsManager) RelayAddrs() []ma.Multiaddr { - a.addrsMx.RLock() - defer a.addrsMx.RUnlock() - return a.getRelayAddrsUnlocked() -} - -func (a *addrsManager) getRelayAddrsUnlocked() []ma.Multiaddr { - return slices.Clone(a.relayAddrs) -} - func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { reachableAddrs, unreachableAddrs = a.addrsReachabilityTracker.ConfirmedAddrs() - return removeIfNotInSource(reachableAddrs, localAddrs), removeIfNotInSource(unreachableAddrs, localAddrs) -} - -// removeIfNotInSource removes items from addrs that are not present in source. -// Modifies the addrs slice in place -// addrs and source must be sorted using multiaddr.Compare. 
-func removeIfNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr { - j := 0 - // mark entries not in source as nil - for i, a := range addrs { - // move right till a is greater - for j < len(source) && a.Compare(source[j]) > 0 { - j++ - } - // a is not in source if we've reached the end, or a is lesser - if j == len(source) || a.Compare(source[j]) < 0 { - addrs[i] = nil - } - // a is in source, nothing to do - } - // j is the current element, i is the lowest index nil element - i := 0 - for j := 0; j < len(addrs); j++ { - if addrs[j] != nil { - addrs[i], addrs[j] = addrs[j], addrs[i] - i++ - } - } - return addrs[:i] + return removeNotInSource(reachableAddrs, localAddrs), removeNotInSource(unreachableAddrs, localAddrs) } var p2pCircuitAddr = ma.StringCast("/p2p-circuit") @@ -701,3 +668,31 @@ func (i *interfaceAddrsCache) updateUnlocked() { } } } + +// removeNotInSource removes items from addrs that are not present in source. +// Modifies the addrs slice in place +// addrs and source must be sorted using multiaddr.Compare. 
+func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr { + j := 0 + // mark entries not in source as nil + for i, a := range addrs { + // move right till a is greater + for j < len(source) && a.Compare(source[j]) > 0 { + j++ + } + // a is not in source if we've reached the end, or a is lesser + if j == len(source) || a.Compare(source[j]) < 0 { + addrs[i] = nil + } + // a is in source, nothing to do + } + // j is the current element, i is the lowest index nil element + i := 0 + for j := range len(addrs) { + if addrs[j] != nil { + addrs[i], addrs[j] = addrs[j], addrs[i] + i++ + } + } + return addrs[:i] +} diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index 41f1e44b1e..56f9faaf42 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -34,7 +34,7 @@ func TestAppendNATAddrs(t *testing.T) { // nat mapping success, obsaddress ignored Listen: ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1"), Nat: ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1"), - ObsAddrFunc: func(m ma.Multiaddr) []ma.Multiaddr { + ObsAddrFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/udp/100/quic-v1")} }, Expected: []ma.Multiaddr{ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1")}, @@ -120,7 +120,7 @@ func TestAppendNATAddrs(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { as := &addrsManager{ natManager: &mockNatManager{ - GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr { + GetMappingFunc: func(_ ma.Multiaddr) ma.Multiaddr { return tc.Nat }, }, @@ -139,7 +139,7 @@ type mockNatManager struct { GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr } -func (m *mockNatManager) Close() error { +func (*mockNatManager) Close() error { return nil } @@ -150,7 +150,7 @@ func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { return m.GetMappingFunc(addr) } -func (m *mockNatManager) HasDiscoveredNAT() bool { +func (*mockNatManager) HasDiscoveredNAT() bool { return 
true } @@ -336,7 +336,7 @@ func TestAddrsManager(t *testing.T) { } am := newAddrsManagerTestCase(t, addrsManagerArgs{ ObservedAddrsManager: &mockObservedAddrs{ - ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return quicAddrs }, }, @@ -352,7 +352,7 @@ func TestAddrsManager(t *testing.T) { t.Run("public addrs removed when private", func(t *testing.T) { am := newAddrsManagerTestCase(t, addrsManagerArgs{ ObservedAddrsManager: &mockObservedAddrs{ - ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return []ma.Multiaddr{publicQUIC} }, }, @@ -394,7 +394,7 @@ func TestAddrsManager(t *testing.T) { return nil }, ObservedAddrsManager: &mockObservedAddrs{ - ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return []ma.Multiaddr{publicQUIC} }, }, @@ -414,7 +414,7 @@ func TestAddrsManager(t *testing.T) { t.Run("updates addresses on signaling", func(t *testing.T) { updateChan := make(chan struct{}) am := newAddrsManagerTestCase(t, addrsManagerArgs{ - AddrsFactory: func(addrs []ma.Multiaddr) []ma.Multiaddr { + AddrsFactory: func(_ []ma.Multiaddr) []ma.Multiaddr { select { case <-updateChan: return []ma.Multiaddr{publicQUIC} @@ -451,7 +451,7 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) { // currently they aren't being passed to the reachability tracker ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicQUIC2, publicTCP} }, AutoNATClient: mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { if reqs[0].Addr.Equal(publicQUIC) { return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil } else if reqs[0].Addr.Equal(publicTCP) || 
reqs[0].Addr.Equal(publicQUIC2) { @@ -496,7 +496,7 @@ func TestRemoveIfNotInSource(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - addrs := removeIfNotInSource(tc.addrs, tc.source) + addrs := removeNotInSource(tc.addrs, tc.source) require.ElementsMatch(t, tc.expected, addrs, "%s\n%s", tc.expected, tc.addrs) }) } @@ -524,6 +524,6 @@ func BenchmarkRemoveIfNotInSource(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - removeIfNotInSource(slices.Clone(addrs[:5]), addrs[:]) + removeNotInSource(slices.Clone(addrs[:5]), addrs[:]) } } diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index 7e091032a3..0dc89276a0 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -7,6 +7,7 @@ import ( "math" "slices" "sync" + "sync/atomic" "time" "github.com/benbjohnson/clock" @@ -33,6 +34,9 @@ const ( newAddrsProbeDelay = 1 * time.Second ) +// addrsReachabilityTracker tracks reachability for addresses. +// Use UpdateAddrs to provide addresses for tracking reachability. +// reachabilityUpdateCh is notified when reachability for any of the tracked address changes. type addrsReachabilityTracker struct { ctx context.Context cancel context.CancelFunc @@ -52,10 +56,8 @@ type addrsReachabilityTracker struct { unreachableAddrs []ma.Multiaddr } -// newAddrsReachabilityTracker tracks reachability for addresses. -// Use UpdateAddrs to provide addresses for tracking reachability. -// reachabilityUpdateCh is notified when any reachability probes are made. The reader must dedup the events. It may be -// notified even when the reachability for any addrs has not changed. +// newAddrsReachabilityTracker returns a new addrsReachabilityTracker. +// reachabilityUpdateCh is notified when reachability for any of the tracked address changes. 
func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh chan struct{}, cl clock.Clock) *addrsReachabilityTracker { ctx, cancel := context.WithCancel(context.Background()) if cl == nil { @@ -75,7 +77,10 @@ func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh ch } func (r *addrsReachabilityTracker) UpdateAddrs(addrs []ma.Multiaddr) { - r.newAddrs <- slices.Clone(addrs) + select { + case r.newAddrs <- slices.Clone(addrs): + case <-r.ctx.Done(): + } } func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) { @@ -86,10 +91,7 @@ func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachable func (r *addrsReachabilityTracker) Start() error { r.wg.Add(1) - err := r.background() - if err != nil { - return err - } + go r.background() return nil } @@ -100,82 +102,87 @@ func (r *addrsReachabilityTracker) Close() error { } const ( - defaultResetInterval = 5 * time.Minute - maxBackoffInterval = 5 * time.Minute + // defaultReachabilityRefreshInterval is the default interval to refresh reachability. + // In steady state, we check for any required probes every refresh interval. + // This doesn't mean we'll probe for any particular address, only that we'll check + // if any address needs to be probed. + defaultReachabilityRefreshInterval = 5 * time.Minute + // maxBackoffInterval is the maximum back off in case we're unable to probe for reachability. + // We may be unable to confirm addresses in case there are no valid peers with autonatv2 + // or the autonatv2 subsystem is consistently erroring. + maxBackoffInterval = 5 * time.Minute + // backoffStartInterval is the initial back off in case we're unable to probe for reachability. 
backoffStartInterval = 5 * time.Second ) -func (r *addrsReachabilityTracker) background() error { - go func() { - defer r.wg.Done() +func (r *addrsReachabilityTracker) background() { + defer r.wg.Done() - // probeTicker is used to trigger probes at regular intervals - probeTicker := r.clock.Ticker(defaultResetInterval) - defer probeTicker.Stop() - - // probeTimer is used to trigger probes at specific times - probeTimer := r.clock.Timer(time.Duration(math.MaxInt64)) - defer probeTimer.Stop() - nextProbeTime := time.Time{} - - var task reachabilityTask - var backoffInterval time.Duration - var currReachable, currUnreachable, prevReachable, prevUnreachable []ma.Multiaddr - for { - select { - case <-probeTicker.C: - // don't start a probe if we have a scheduled probe - if task.RespCh == nil && nextProbeTime.IsZero() { - task = r.refreshReachability() - } - case <-probeTimer.C: - if task.RespCh == nil { - task = r.refreshReachability() - } - nextProbeTime = time.Time{} - case backoff := <-task.RespCh: + // probeTicker is used to trigger probes at regular intervals + probeTicker := r.clock.Ticker(defaultReachabilityRefreshInterval) + defer probeTicker.Stop() + + // probeTimer is used to trigger probes at specific times + probeTimer := r.clock.Timer(time.Duration(math.MaxInt64)) + defer probeTimer.Stop() + nextProbeTime := time.Time{} + + var task reachabilityTask + var backoffInterval time.Duration + var currReachable, currUnreachable, prevReachable, prevUnreachable []ma.Multiaddr + for { + select { + case <-probeTicker.C: + // don't start a probe if we have a scheduled probe + if task.RespCh == nil && nextProbeTime.IsZero() { + task = r.refreshReachability() + } + case <-probeTimer.C: + if task.RespCh == nil { + task = r.refreshReachability() + } + nextProbeTime = time.Time{} + case backoff := <-task.RespCh: + task = reachabilityTask{} + // On completion, start the next probe immediately, or wait for backoff. 
+ // In case there are no further probes, the reachability tracker will return an empty task, + // which hangs forever. Eventually, we'll refresh again when the ticker fires. + if backoff { + backoffInterval = newBackoffInterval(backoffInterval) + } else { + backoffInterval = -1 * time.Second // negative to trigger next probe immediately + } + nextProbeTime = r.clock.Now().Add(backoffInterval) + case addrs := <-r.newAddrs: + if task.RespCh != nil { // cancel running task. + task.Cancel() + <-task.RespCh // ignore backoff from cancelled task task = reachabilityTask{} - // On completion, start the next probe immediately, or wait for backoff - // In case there are no further probes, the reachability tracker will return an empty task, - // which hangs forever. Eventually, we'll refresh again when the ticker fires. - if backoff { - backoffInterval = newBackoffInterval(backoffInterval) - } else { - backoffInterval = -1 * time.Second // negative to trigger next probe immediately - } - nextProbeTime = r.clock.Now().Add(backoffInterval) - case addrs := <-r.newAddrs: - if task.RespCh != nil { // cancel running task. - task.Cancel() - <-task.RespCh // ignore backoff from cancelled task - task = reachabilityTask{} - } - r.updateTrackedAddrs(addrs) - newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay) - if nextProbeTime.Before(newAddrsNextTime) { - nextProbeTime = newAddrsNextTime - } - case <-r.ctx.Done(): - if task.RespCh != nil { - task.Cancel() - <-task.RespCh - task = reachabilityTask{} - } - return } - - currReachable, currUnreachable = r.appendConfirmedAddrs(currReachable[:0], currUnreachable[:0]) - if areAddrsDifferent(prevReachable, currReachable) || areAddrsDifferent(prevUnreachable, currUnreachable) { - r.notify() + r.updateTrackedAddrs(addrs) + newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay) + if nextProbeTime.Before(newAddrsNextTime) { + nextProbeTime = newAddrsNextTime } - prevReachable = append(prevReachable[:0], currReachable...) 
- prevUnreachable = append(prevUnreachable[:0], currUnreachable...) - if !nextProbeTime.IsZero() { - probeTimer.Reset(nextProbeTime.Sub(r.clock.Now())) + case <-r.ctx.Done(): + if task.RespCh != nil { + task.Cancel() + <-task.RespCh + task = reachabilityTask{} } + return } - }() - return nil + + currReachable, currUnreachable = r.appendConfirmedAddrs(currReachable[:0], currUnreachable[:0]) + if areAddrsDifferent(prevReachable, currReachable) || areAddrsDifferent(prevUnreachable, currUnreachable) { + r.notify() + } + prevReachable = append(prevReachable[:0], currReachable...) + prevUnreachable = append(prevUnreachable[:0], currUnreachable...) + if !nextProbeTime.IsZero() { + probeTimer.Reset(nextProbeTime.Sub(r.clock.Now())) + } + } } func newBackoffInterval(current time.Duration) time.Duration { @@ -268,12 +275,6 @@ func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv return res, err } -type probeResponse struct { - Req []autonatv2.Request - Res autonatv2.Result - Err error -} - const maxConsecutiveErrors = 20 // runProbes runs probes provided by addrsTracker with the given client. 
It returns true if the caller should @@ -284,55 +285,34 @@ const maxConsecutiveErrors = 20 // - the client has no valid peers to probe func runProbes(ctx context.Context, concurrency int, addrsTracker *probeManager, client autonatv2Client) bool { client = &errCountingClient{autonatv2Client: client, MaxConsecutiveErrors: maxConsecutiveErrors} - - resultsCh := make(chan probeResponse, 2*concurrency) // enough buffer to allow all worker goroutines to exit quickly - jobsCh := make(chan []autonatv2.Request, 1) // close jobs to terminate the workers + var backoff atomic.Bool var wg sync.WaitGroup wg.Add(concurrency) for range concurrency { go func() { defer wg.Done() - for reqs := range jobsCh { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - res, err := client.GetReachability(ctx, reqs) + for { + if ctx.Err() != nil { + return + } + reqs := addrsTracker.GetProbe() + if len(reqs) == 0 { + return + } + rctx, cancel := context.WithTimeout(ctx, 30*time.Second) + addrsTracker.MarkProbeInProgress(reqs) + res, err := client.GetReachability(rctx, reqs) cancel() - resultsCh <- probeResponse{Req: reqs, Res: res, Err: err} + addrsTracker.CompleteProbe(reqs, res, err) + if isErrorPersistent(err) { + backoff.Store(true) + return + } } }() } - - nextProbe := addrsTracker.GetProbe() - backoff := false -outer: - for jc := jobsCh; addrsTracker.InProgressProbes() > 0 || len(nextProbe) > 0; { - select { - case jc <- nextProbe: - addrsTracker.MarkProbeInProgress(nextProbe) - case resp := <-resultsCh: - addrsTracker.CompleteProbe(resp.Req, resp.Res, resp.Err) - if isErrorPersistent(resp.Err) { - backoff = true - break outer - } - case <-ctx.Done(): - break outer - } - jc = jobsCh - nextProbe = addrsTracker.GetProbe() - if len(nextProbe) == 0 { - jc = nil - } - } - close(jobsCh) - for addrsTracker.InProgressProbes() > 0 { - resp := <-resultsCh - addrsTracker.CompleteProbe(resp.Req, resp.Res, resp.Err) - if isErrorPersistent(resp.Err) { - backoff = true - } - } wg.Wait() - return 
backoff + return backoff.Load() } // isErrorPersistent returns whether the error will repeat on future probes for a while @@ -357,7 +337,7 @@ const ( maxRecentDialsPerAddr = 10 // confidence is the absolute difference between the number of successes and failures for an address // targetConfidence is the confidence threshold for an address after which we wait for `maxProbeInterval` - // before probing again + // before probing again. targetConfidence = 3 // minConfidence is the confidence threshold for an address to be considered reachable or unreachable // confidence is the absolute difference between the number of successes and failures for an address @@ -368,10 +348,10 @@ const ( // and then a success(...S S S S F S). The confidence in the targetConfidence window will be equal to // targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval. maxRecentDialsWindow = targetConfidence + 2 - // maxProbeInterval is the maximum interval between probes for an address - maxProbeInterval = 1 * time.Hour + // highConfidenceAddrProbeInterval is the maximum interval between probes for an address + highConfidenceAddrProbeInterval = 1 * time.Hour // maxProbeResultTTL is the maximum time to keep probe results for an address - maxProbeResultTTL = maxRecentDialsWindow * maxProbeInterval + maxProbeResultTTL = maxRecentDialsWindow * highConfidenceAddrProbeInterval ) // probeManager tracks reachability for a set of addresses by periodically probing reachability with autonatv2. 
@@ -406,8 +386,8 @@ func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiadd for _, a := range m.addrs { s := m.statuses[string(a.Bytes())] - s.outcomes.RemoveBefore(m.now().Add(-maxProbeResultTTL)) // cleanup stale results - switch s.outcomes.Reachability() { + s.RemoveBefore(m.now().Add(-maxProbeResultTTL)) // cleanup stale results + switch s.Reachability() { case network.ReachabilityPublic: reachable = append(reachable, a) case network.ReachabilityPrivate: @@ -423,23 +403,20 @@ func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) { defer m.mx.Unlock() slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) - + statuses := make(map[string]*addrStatus, len(addrs)) for _, addr := range addrs { k := string(addr.Bytes()) if _, ok := m.statuses[k]; !ok { - m.statuses[k] = &addrStatus{Addr: addr, outcomes: addrOutcomes{}} - } - } - for k, s := range m.statuses { - _, ok := slices.BinarySearchFunc(addrs, s.Addr, func(a, b ma.Multiaddr) int { return a.Compare(b) }) - if !ok { - delete(m.statuses, k) + statuses[k] = &addrStatus{Addr: addr} + } else { + statuses[k] = m.statuses[k] } } m.addrs = addrs + m.statuses = statuses } -// GetProbe returns the next probe. Returns empty slice in case there are no more probes. +// GetProbe returns the next probe. Returns zero value in case there are no more probes. // Probes that are run against an autonatv2 client should be marked in progress with // `MarkProbeInProgress` before running. 
func (m *probeManager) GetProbe() []autonatv2.Request { @@ -449,7 +426,7 @@ func (m *probeManager) GetProbe() []autonatv2.Request { now := m.now() for i, a := range m.addrs { ab := a.Bytes() - pc := m.requiredProbeCount(m.statuses[string(ab)], now) + pc := m.statuses[string(ab)].RequiredProbeCount(now) if pc == 0 { continue } @@ -464,7 +441,7 @@ func (m *probeManager) GetProbe() []autonatv2.Request { for j := 1; j < len(m.addrs); j++ { k := (i + j) % len(m.addrs) ab := m.addrs[k].Bytes() - pc := m.requiredProbeCount(m.statuses[string(ab)], now) + pc := m.statuses[string(ab)].RequiredProbeCount(now) if pc == 0 { continue } @@ -523,42 +500,45 @@ func (m *probeManager) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Res } // Consider only primary address as refused. This increases the number of - // probes are refused, but refused probes are cheap as no dial is - // made by the server. + // refused probes, but refused probes are cheap for a server as no dials are made. if res.AllAddrsRefused { if s, ok := m.statuses[primaryAddrKey]; ok { - m.addRefusal(s, now) + s.AddRefusal(now) } return } dialAddrKey := string(res.Addr.Bytes()) if dialAddrKey != primaryAddrKey { if s, ok := m.statuses[primaryAddrKey]; ok { - m.addRefusal(s, now) + s.AddRefusal(now) } } - // record the result for the dialled address - expireBefore := now.Add(-maxProbeInterval) + // record the result for the dialed address if s, ok := m.statuses[dialAddrKey]; ok { - m.addDial(s, now, res.Reachability, expireBefore) + s.AddOutcome(now, res.Reachability, maxRecentDialsWindow) } } -func (*probeManager) addRefusal(s *addrStatus, now time.Time) { - s.lastRefusalTime = now - s.consecutiveRefusals++ +type dialOutcome struct { + Success bool + At time.Time } -func (*probeManager) addDial(s *addrStatus, now time.Time, rch network.Reachability, expireBefore time.Time) { - s.lastRefusalTime = time.Time{} - s.consecutiveRefusals = 0 - s.dialTimes = append(s.dialTimes, now) - s.outcomes.AddOutcome(now, rch, 
maxRecentDialsWindow) - s.outcomes.RemoveBefore(expireBefore) +type addrStatus struct { + Addr ma.Multiaddr + lastRefusalTime time.Time + consecutiveRefusals int + dialTimes []time.Time + outcomes []dialOutcome } -func (m *probeManager) requiredProbeCount(s *addrStatus, now time.Time) int { +func (s *addrStatus) Reachability() network.Reachability { + rch, _, _ := s.reachabilityAndCounts() + return rch +} + +func (s *addrStatus) RequiredProbeCount(now time.Time) int { if s.consecutiveRefusals >= maxConsecutiveRefusals { if now.Sub(s.lastRefusalTime) < recentProbeInterval { return 0 @@ -569,49 +549,16 @@ func (m *probeManager) requiredProbeCount(s *addrStatus, now time.Time) int { } // Don't probe if we have probed too many times recently - rd := m.recentDialCount(s, now) + rd := s.recentDialCount(now) if rd >= maxRecentDialsPerAddr { return 0 } - return s.outcomes.RequiredProbeCount(now) -} - -func (*probeManager) recentDialCount(s *addrStatus, now time.Time) int { - cnt := 0 - for _, t := range slices.Backward(s.dialTimes) { - if now.Sub(t) > recentProbeInterval { - break - } - cnt++ - } - return cnt -} - -type dialOutcome struct { - Success bool - At time.Time -} - -type addrStatus struct { - Addr ma.Multiaddr - lastRefusalTime time.Time - consecutiveRefusals int - dialTimes []time.Time - outcomes addrOutcomes -} - -type addrOutcomes struct { - outcomes []dialOutcome + return s.requiredProbeCountForConfirmation(now) } -func (o *addrOutcomes) Reachability() network.Reachability { - rch, _, _ := o.reachabilityAndCounts() - return rch -} - -func (o *addrOutcomes) RequiredProbeCount(now time.Time) int { - reachability, successes, failures := o.reachabilityAndCounts() +func (s *addrStatus) requiredProbeCountForConfirmation(now time.Time) int { + reachability, successes, failures := s.reachabilityAndCounts() confidence := successes - failures if confidence < 0 { confidence = -confidence @@ -623,12 +570,12 @@ func (o *addrOutcomes) RequiredProbeCount(now time.Time) int 
{ // we have enough confirmations; check if we should refresh // Should never happen. The confidence logic above should require a few probes. - if len(o.outcomes) == 0 { + if len(s.outcomes) == 0 { return 0 } - lastOutcome := o.outcomes[len(o.outcomes)-1] + lastOutcome := s.outcomes[len(s.outcomes)-1] // If the last probe result is old, we need to retest - if now.Sub(lastOutcome.At) > maxProbeInterval { + if now.Sub(lastOutcome.At) > highConfidenceAddrProbeInterval { return 1 } // if the last probe result was different from reachability, probe again. @@ -648,7 +595,24 @@ func (o *addrOutcomes) RequiredProbeCount(now time.Time) int { return 0 } -func (o *addrOutcomes) AddOutcome(at time.Time, rch network.Reachability, windowSize int) { +func (s *addrStatus) AddRefusal(now time.Time) { + s.lastRefusalTime = now + s.consecutiveRefusals++ +} + +func (s *addrStatus) AddOutcome(at time.Time, rch network.Reachability, windowSize int) { + s.lastRefusalTime = time.Time{} + s.consecutiveRefusals = 0 + + s.dialTimes = append(s.dialTimes, at) + for i, t := range s.dialTimes { + if at.Sub(t) < recentProbeInterval { + s.dialTimes = slices.Delete(s.dialTimes, 0, i) + break + } + } + + s.RemoveBefore(at.Add(-maxProbeResultTTL)) // remove old outcomes success := false switch rch { case network.ReachabilityPublic: @@ -658,25 +622,36 @@ func (o *addrOutcomes) AddOutcome(at time.Time, rch network.Reachability, window default: return // don't store the outcome if reachability is unknown } - o.outcomes = append(o.outcomes, dialOutcome{At: at, Success: success}) - if len(o.outcomes) > windowSize { - o.outcomes = slices.Delete(o.outcomes, 0, len(o.outcomes)-windowSize) + s.outcomes = append(s.outcomes, dialOutcome{At: at, Success: success}) + if len(s.outcomes) > windowSize { + s.outcomes = slices.Delete(s.outcomes, 0, len(s.outcomes)-windowSize) } } // RemoveBefore removes outcomes before t -func (o *addrOutcomes) RemoveBefore(t time.Time) { +func (s *addrStatus) RemoveBefore(t 
time.Time) { var end = 0 - for ; end < len(o.outcomes); end++ { - if !o.outcomes[end].At.Before(t) { + for ; end < len(s.outcomes); end++ { + if !s.outcomes[end].At.Before(t) { break } } - o.outcomes = slices.Delete(o.outcomes, 0, end) + s.outcomes = slices.Delete(s.outcomes, 0, end) +} + +func (s *addrStatus) recentDialCount(now time.Time) int { + cnt := 0 + for _, t := range slices.Backward(s.dialTimes) { + if now.Sub(t) > recentProbeInterval { + break + } + cnt++ + } + return cnt } -func (o *addrOutcomes) reachabilityAndCounts() (rch network.Reachability, successes int, failures int) { - for _, r := range o.outcomes { +func (s *addrStatus) reachabilityAndCounts() (rch network.Reachability, successes int, failures int) { + for _, r := range s.outcomes { if r.Success { successes++ } else { diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index 74c341c7c6..84223291de 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -125,7 +125,7 @@ func TestProbeManager(t *testing.T) { reqs := pm.GetProbe() require.Empty(t, reqs) - cl.Add(maxProbeInterval + time.Millisecond) + cl.Add(highConfidenceAddrProbeInterval + time.Millisecond) reqs = nextProbe(pm) require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) reqs = nextProbe(pm) @@ -239,7 +239,7 @@ func TestAddrsReachabilityTracker(t *testing.T) { t.Run("simple", func(t *testing.T) { // pub1 reachable, pub2 unreachable, pub3 ignored mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { for i, req := range reqs { if req.Addr.Equal(pub1) { return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil @@ -274,7 +274,7 @@ func TestAddrsReachabilityTracker(t *testing.T) 
{ t.Run("confirmed addrs ordering", func(t *testing.T) { mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil }, } @@ -314,7 +314,7 @@ func TestAddrsReachabilityTracker(t *testing.T) { var allow atomic.Bool mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { select { case notify <- struct{}{}: default: @@ -378,7 +378,7 @@ func TestAddrsReachabilityTracker(t *testing.T) { called := make(chan struct{}, minConfidence) notify := make(chan struct{}) mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { select { case called <- struct{}{}: notify <- struct{}{} @@ -438,7 +438,7 @@ func TestAddrsReachabilityTracker(t *testing.T) { } mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { select { case notify <- struct{}{}: default: @@ -461,14 +461,14 @@ func TestAddrsReachabilityTracker(t *testing.T) { time.Sleep(100 * time.Millisecond) require.True(t, drainNotify()) // check that we did receive probes - cl.Add(maxProbeInterval / 2) + cl.Add(highConfidenceAddrProbeInterval / 2) select { case <-notify: t.Fatal("unexpected call") case <-time.After(50 * time.Millisecond): } - cl.Add(maxProbeInterval/2 + defaultResetInterval) // defaultResetInterval for the next probe time + cl.Add(highConfidenceAddrProbeInterval/2 + defaultReachabilityRefreshInterval) // defaultResetInterval for the next probe time select 
{ case <-notify: case <-time.After(1 * time.Second): @@ -484,7 +484,7 @@ func TestRunProbes(t *testing.T) { defer cancel() t.Run("backoff on ErrNoValidPeers", func(t *testing.T) { mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { return autonatv2.Result{}, autonatv2.ErrNoPeers }, } @@ -499,7 +499,7 @@ func TestRunProbes(t *testing.T) { t.Run("returns backoff on errTooManyConsecutiveFailures", func(t *testing.T) { // Create a client that always returns ErrDialRefused mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { return autonatv2.Result{}, errors.New("test error") }, } @@ -516,7 +516,7 @@ func TestRunProbes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) block := make(chan struct{}) mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { block <- struct{}{} return autonatv2.Result{}, nil }, @@ -559,7 +559,7 @@ func TestRunProbes(t *testing.T) { addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { for i, req := range reqs { if req.Addr.Equal(pub1) { return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil @@ -583,7 +583,7 @@ func TestRunProbes(t *testing.T) { addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) mockClient := mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs 
[]autonatv2.Request) (autonatv2.Result, error) { for i, req := range reqs { if req.Addr.Equal(pub1) { return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil @@ -606,34 +606,44 @@ func TestRunProbes(t *testing.T) { }) } -func TestDialOutcome(t *testing.T) { +func TestAddrStatusProbeCount(t *testing.T) { cases := []struct { inputs string wantRequiredProbes int wantReachability network.Reachability }{ { - inputs: "SSSSSSSSSSS", - wantRequiredProbes: 0, - wantReachability: network.ReachabilityPublic, + inputs: "", + wantRequiredProbes: 3, + wantReachability: network.ReachabilityUnknown, }, { - inputs: "SSSSSSSSSSF", + inputs: "S", + wantRequiredProbes: 2, + wantReachability: network.ReachabilityUnknown, + }, + { + inputs: "SS", wantRequiredProbes: 1, wantReachability: network.ReachabilityPublic, }, { - inputs: "SFSFSFSFSSSS", + inputs: "SSS", wantRequiredProbes: 0, wantReachability: network.ReachabilityPublic, }, { - inputs: "SSSSSSSSSFSF", - wantRequiredProbes: 2, - wantReachability: network.ReachabilityUnknown, + inputs: "SSSSSSSF", + wantRequiredProbes: 1, + wantReachability: network.ReachabilityPublic, }, { - inputs: "S", + inputs: "SFSFSSSS", + wantRequiredProbes: 0, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SSSSSFSF", wantRequiredProbes: 2, wantReachability: network.ReachabilityUnknown, }, @@ -646,7 +656,7 @@ func TestDialOutcome(t *testing.T) { for _, c := range cases { t.Run(c.inputs, func(t *testing.T) { now := time.Time{}.Add(1 * time.Second) - ao := addrOutcomes{} + ao := addrStatus{} for _, r := range c.inputs { if r == 'S' { ao.AddOutcome(now, network.ReachabilityPublic, 5) @@ -658,7 +668,7 @@ func TestDialOutcome(t *testing.T) { require.Equal(t, ao.RequiredProbeCount(now), c.wantRequiredProbes) require.Equal(t, ao.Reachability(), c.wantReachability) if c.wantRequiredProbes == 0 { - now = now.Add(maxProbeInterval + 10*time.Microsecond) + now = now.Add(highConfidenceAddrProbeInterval + 
10*time.Microsecond) require.Equal(t, ao.RequiredProbeCount(now), 1) } @@ -673,9 +683,9 @@ func BenchmarkAddrTracker(b *testing.B) { cl := clock.NewMock() t := newProbeManager(cl.Now) - var addrs []ma.Multiaddr - for range 20 { - addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", rand.Intn(1000)))) + addrs := make([]ma.Multiaddr, 20) + for i := range addrs { + addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", rand.Intn(1000))) } t.UpdateAddrs(addrs) b.ReportAllocs() @@ -700,7 +710,7 @@ func FuzzAddrsReachabilityTracker(f *testing.F) { newMockClient := func(b []byte) mockAutoNATClient { count := 0 return mockAutoNATClient{ - F: func(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { if len(b) == 0 { return autonatv2.Result{}, nil } @@ -782,7 +792,7 @@ func FuzzAddrsReachabilityTracker(f *testing.F) { ipType = int(ips[0]) } ips = ips[1:] - var x, y int64 = 0, 0 + var x, y int64 split := 128 / 8 if len(ips) < split { split = len(ips) diff --git a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/4f31b7942ec62406 b/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/4f31b7942ec62406 deleted file mode 100644 index 2badd5771a..0000000000 --- a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/4f31b7942ec62406 +++ /dev/null @@ -1,6 +0,0 @@ -go test fuzz v1 -int(93) -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("") diff --git a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/53e52cff547ff885 b/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/53e52cff547ff885 deleted file mode 100644 index 95aa2608e7..0000000000 --- a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/53e52cff547ff885 +++ /dev/null @@ -1,6 +0,0 @@ -go test fuzz v1 -int(0) -[]byte("") -[]byte("") -[]byte("0") -[]byte("") diff --git a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/79485637e486f9db 
b/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/79485637e486f9db deleted file mode 100644 index 1f73c2a028..0000000000 --- a/p2p/host/basic/testdata/fuzz/FuzzAddrsReachabilityTracker/79485637e486f9db +++ /dev/null @@ -1,6 +0,0 @@ -go test fuzz v1 -int(0) -[]byte("0") -[]byte("\xc1\xd6,\x9e.]") -[]byte("0") -[]byte("") diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index 26870976fb..d03bc66b8f 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -81,9 +81,11 @@ type AutoNAT struct { srv *server cli *client - mx sync.Mutex - peers *peersMap - throttlePeer map[peer.ID]time.Time + mx sync.Mutex + peers *peersMap + throttlePeer map[peer.ID]time.Time + // throttlePeerDuration is the duration to wait before making another dial request to the + // same server. throttlePeerDuration time.Duration // allowPrivateAddrs enables using private and localhost addresses for reachability checks. // This is only useful for testing. 
@@ -208,7 +210,7 @@ func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, if p == "" { return Result{}, ErrNoPeers } - res, err := an.cli.GetReachability(ctx, p, reqs) + res, err := an.cli.GetReachability(ctx, p, filteredReqs) if err != nil { log.Debugf("reachability check with %s failed, err: %s", p, err) return res, fmt.Errorf("reachability check with %s failed: %w", p, err) @@ -258,7 +260,7 @@ func (p *peersMap) Shuffled() iter.Seq[peer.ID] { n := len(p.peers) start := 0 if n > 0 { - start = rand.IntN(len(p.peers)) + start = rand.IntN(n) } return func(yield func(peer.ID) bool) { for i := range n { diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 5811105585..4cbfde4c61 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -291,6 +291,49 @@ func TestClientDataRequest(t *testing.T) { } } +func TestAutoNATPrivateAndPublicAddrs(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer dialerHost.Close() + handler := func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_E_DIAL_ERROR, + AddrIdx: 0, + }, + }, + }) + s.Close() + } + + b.SetStreamHandler(DialProtocol, handler) + privateAddr := ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1") + publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/10/quic-v1") + res, err := an.GetReachability(context.Background(), + []Request{ + {Addr: privateAddr}, + {Addr: publicAddr}, + }) + require.NoError(t, err) + require.Equal(t, res.Addr, publicAddr, 
"%s\n%s", res.Addr, publicAddr) + require.Equal(t, res.Idx, 1) + require.Equal(t, res.Reachability, network.ReachabilityPrivate) +} + func TestClientDialBacks(t *testing.T) { an := newAutoNAT(t, nil, allowPrivateAddrs) defer an.Close() @@ -661,7 +704,7 @@ func FuzzClient(f *testing.F) { ipType = int(ips[0]) } ips = ips[1:] - var x, y int64 = 0, 0 + var x, y int64 split := 128 / 8 if len(ips) < split { split = len(ips) @@ -751,7 +794,7 @@ func FuzzClient(f *testing.F) { return addrs } // reduce the streamTimeout before running this. TODO: fix this - f.Fuzz(func(t *testing.T, numAddrs int, ips, protos, hostNames []byte) { + f.Fuzz(func(_ *testing.T, numAddrs int, ips, protos, hostNames []byte) { addrs := getAddrs(numAddrs, ips, protos, hostNames) reqs := make([]Request, len(addrs)) for i, addr := range addrs { diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index 15b454e35f..7cd0dba5f0 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -109,7 +109,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request break // provide dial data if appropriate case msg.GetDialDataRequest() != nil: - if err := ac.validateDialDataRequest(reqs, &msg); err != nil { + if err := validateDialDataRequest(reqs, &msg); err != nil { s.Reset() return Result{}, fmt.Errorf("invalid dial data request: %s %w", s.Conn().RemoteMultiaddr(), err) } @@ -162,7 +162,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request return ac.newResult(resp, reqs, dialBackAddr) } -func (*client) validateDialDataRequest(reqs []Request, msg *pb.Message) error { +func validateDialDataRequest(reqs []Request, msg *pb.Message) error { idx := int(msg.GetDialDataRequest().AddrIdx) if idx >= len(reqs) { // invalid address index return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs)) @@ -184,7 +184,7 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr } 
addr := reqs[idx].Addr - rch := network.ReachabilityUnknown + rch := network.ReachabilityUnknown //nolint:ineffassign switch resp.DialStatus { case pb.DialStatus_OK: if !ac.areAddrsConsistent(dialBackAddr, addr) { From f6f9d3796da81894a6d117ffd49400eb154338ae Mon Sep 17 00:00:00 2001 From: guillaumemichel Date: Thu, 22 May 2025 10:45:37 +0200 Subject: [PATCH 14/15] use go-clock --- p2p/host/basic/addrs_reachability_tracker.go | 4 ++-- p2p/host/basic/addrs_reachability_tracker_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index 0dc89276a0..e2a73980dd 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - "github.com/benbjohnson/clock" + "github.com/filecoin-project/go-clock" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" ma "github.com/multiformats/go-multiaddr" @@ -630,7 +630,7 @@ func (s *addrStatus) AddOutcome(at time.Time, rch network.Reachability, windowSi // RemoveBefore removes outcomes before t func (s *addrStatus) RemoveBefore(t time.Time) { - var end = 0 + end := 0 for ; end < len(s.outcomes); end++ { if !s.outcomes[end].At.Before(t) { break diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index 84223291de..f131e40380 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -15,7 +15,7 @@ import ( "testing" "time" - "github.com/benbjohnson/clock" + "github.com/filecoin-project/go-clock" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" ma "github.com/multiformats/go-multiaddr" From 2d3c76b620e9d363a1bb031c7f3bb7089f22182b Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 1 Jun 2025 12:51:20 +0530 Subject: [PATCH 
15/15] more review comments --- p2p/host/basic/addrs_manager.go | 10 +- p2p/host/basic/addrs_reachability_tracker.go | 136 +++++++++--------- .../basic/addrs_reachability_tracker_test.go | 82 ++++++----- p2p/protocol/autonatv2/autonat.go | 10 +- p2p/protocol/autonatv2/autonat_test.go | 26 ++++ 5 files changed, 158 insertions(+), 106 deletions(-) diff --git a/p2p/host/basic/addrs_manager.go b/p2p/host/basic/addrs_manager.go index 6c870f43cc..217570ae86 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -294,6 +294,14 @@ func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, curre } } + // We *must* send both reachability changed and addrs changed events from the + // same goroutine to ensure correct ordering + // Consider the events: + // - addr x discovered + // - addr x is reachable + // - addr x removed + // We must send these events in the same order. It'll be confusing for consumers + // if the reachable event is received after the addr removed event. 
if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) || areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) { log.Debugf("host reachable addrs updated: %s", current.localAddrs) @@ -676,7 +684,7 @@ func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr { j := 0 // mark entries not in source as nil for i, a := range addrs { - // move right till a is greater + // move right as long as a > source[j] for j < len(source) && a.Compare(source[j]) > 0 { j++ } diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go index 0dc89276a0..80c769e18c 100644 --- a/p2p/host/basic/addrs_reachability_tracker.go +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -42,12 +42,12 @@ type addrsReachabilityTracker struct { cancel context.CancelFunc wg sync.WaitGroup - cli autonatv2Client + client autonatv2Client // reachabilityUpdateCh is used to notify when reachability may have changed reachabilityUpdateCh chan struct{} maxConcurrency int newAddrsProbeDelay time.Duration - addrTracker *probeManager + probeManager *probeManager newAddrs chan []ma.Multiaddr clock clock.Clock @@ -66,9 +66,9 @@ func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh ch return &addrsReachabilityTracker{ ctx: ctx, cancel: cancel, - cli: client, + client: client, reachabilityUpdateCh: reachabilityUpdateCh, - addrTracker: newProbeManager(cl.Now), + probeManager: newProbeManager(cl.Now), newAddrsProbeDelay: newAddrsProbeDelay, maxConcurrency: defaultMaxConcurrency, newAddrs: make(chan []ma.Multiaddr, 1), @@ -134,15 +134,15 @@ func (r *addrsReachabilityTracker) background() { select { case <-probeTicker.C: // don't start a probe if we have a scheduled probe - if task.RespCh == nil && nextProbeTime.IsZero() { + if task.BackoffCh == nil && nextProbeTime.IsZero() { task = r.refreshReachability() } case <-probeTimer.C: - if task.RespCh == nil { + if task.BackoffCh == nil { task = 
r.refreshReachability() } nextProbeTime = time.Time{} - case backoff := <-task.RespCh: + case backoff := <-task.BackoffCh: task = reachabilityTask{} // On completion, start the next probe immediately, or wait for backoff. // In case there are no further probes, the reachability tracker will return an empty task, @@ -154,9 +154,9 @@ func (r *addrsReachabilityTracker) background() { } nextProbeTime = r.clock.Now().Add(backoffInterval) case addrs := <-r.newAddrs: - if task.RespCh != nil { // cancel running task. + if task.BackoffCh != nil { // cancel running task. task.Cancel() - <-task.RespCh // ignore backoff from cancelled task + <-task.BackoffCh // ignore backoff from cancelled task task = reachabilityTask{} } r.updateTrackedAddrs(addrs) @@ -165,9 +165,9 @@ func (r *addrsReachabilityTracker) background() { nextProbeTime = newAddrsNextTime } case <-r.ctx.Done(): - if task.RespCh != nil { + if task.BackoffCh != nil { task.Cancel() - <-task.RespCh + <-task.BackoffCh task = reachabilityTask{} } return @@ -197,7 +197,7 @@ func newBackoffInterval(current time.Duration) time.Duration { } func (r *addrsReachabilityTracker) appendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { - reachable, unreachable = r.addrTracker.AppendConfirmedAddrs(reachable, unreachable) + reachable, unreachable = r.probeManager.AppendConfirmedAddrs(reachable, unreachable) r.mx.Lock() r.reachableAddrs = append(r.reachableAddrs[:0], reachable...) r.unreachableAddrs = append(r.unreachableAddrs[:0], unreachable...) 
@@ -220,30 +220,69 @@ func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) { log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs) addrs = addrs[:maxTrackedAddrs] } - r.addrTracker.UpdateAddrs(addrs) + r.probeManager.UpdateAddrs(addrs) } +type probe = []autonatv2.Request + +const probeTimeout = 30 * time.Second + // reachabilityTask is a task to refresh reachability. // Waiting on the zero value blocks forever. type reachabilityTask struct { Cancel context.CancelFunc - RespCh chan bool + // BackoffCh returns whether the caller should backoff before + // refreshing reachability + BackoffCh chan bool } func (r *addrsReachabilityTracker) refreshReachability() reachabilityTask { - if len(r.addrTracker.GetProbe()) == 0 { + if len(r.probeManager.GetProbe()) == 0 { return reachabilityTask{} } resCh := make(chan bool, 1) ctx, cancel := context.WithTimeout(r.ctx, 5*time.Minute) r.wg.Add(1) + // We run probes provided by addrsTracker. 
It stops probing when any + // of the following happens: + // - there are no more probes to run + // - context is completed + // - there are too many consecutive failures from the client + // - the client has no valid peers to probe go func() { defer r.wg.Done() defer cancel() - backoff := runProbes(ctx, r.maxConcurrency, r.addrTracker, r.cli) - resCh <- backoff + client := &errCountingClient{autonatv2Client: r.client, MaxConsecutiveErrors: maxConsecutiveErrors} + var backoff atomic.Bool + var wg sync.WaitGroup + wg.Add(r.maxConcurrency) + for range r.maxConcurrency { + go func() { + defer wg.Done() + for { + if ctx.Err() != nil { + return + } + reqs := r.probeManager.GetProbe() + if len(reqs) == 0 { + return + } + r.probeManager.MarkProbeInProgress(reqs) + rctx, cancel := context.WithTimeout(ctx, probeTimeout) + res, err := client.GetReachability(rctx, reqs) + cancel() + r.probeManager.CompleteProbe(reqs, res, err) + if isErrorPersistent(err) { + backoff.Store(true) + return + } + } + }() + } + wg.Wait() + resCh <- backoff.Load() }() - return reachabilityTask{Cancel: cancel, RespCh: resCh} + return reachabilityTask{Cancel: cancel, BackoffCh: resCh} } var errTooManyConsecutiveFailures = errors.New("too many consecutive failures") @@ -257,7 +296,7 @@ type errCountingClient struct { consecutiveErrors int } -func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { +func (c *errCountingClient) GetReachability(ctx context.Context, reqs probe) (autonatv2.Result, error) { res, err := c.autonatv2Client.GetReachability(ctx, reqs) c.mx.Lock() defer c.mx.Unlock() @@ -277,44 +316,6 @@ func (c *errCountingClient) GetReachability(ctx context.Context, reqs []autonatv const maxConsecutiveErrors = 20 -// runProbes runs probes provided by addrsTracker with the given client. It returns true if the caller should -// backoff before retrying probes. 
It stops probing when any of the following happens: -// - there are no more probes to run -// - context is completed -// - there are too many consecutive failures from the client -// - the client has no valid peers to probe -func runProbes(ctx context.Context, concurrency int, addrsTracker *probeManager, client autonatv2Client) bool { - client = &errCountingClient{autonatv2Client: client, MaxConsecutiveErrors: maxConsecutiveErrors} - var backoff atomic.Bool - var wg sync.WaitGroup - wg.Add(concurrency) - for range concurrency { - go func() { - defer wg.Done() - for { - if ctx.Err() != nil { - return - } - reqs := addrsTracker.GetProbe() - if len(reqs) == 0 { - return - } - rctx, cancel := context.WithTimeout(ctx, 30*time.Second) - addrsTracker.MarkProbeInProgress(reqs) - res, err := client.GetReachability(rctx, reqs) - cancel() - addrsTracker.CompleteProbe(reqs, res, err) - if isErrorPersistent(err) { - backoff.Store(true) - return - } - } - }() - } - wg.Wait() - return backoff.Load() -} - // isErrorPersistent returns whether the error will repeat on future probes for a while func isErrorPersistent(err error) bool { if err == nil { @@ -419,29 +420,26 @@ func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) { // GetProbe returns the next probe. Returns zero value in case there are no more probes. // Probes that are run against an autonatv2 client should be marked in progress with // `MarkProbeInProgress` before running. 
-func (m *probeManager) GetProbe() []autonatv2.Request { +func (m *probeManager) GetProbe() probe { m.mx.Lock() defer m.mx.Unlock() now := m.now() for i, a := range m.addrs { - ab := a.Bytes() - pc := m.statuses[string(ab)].RequiredProbeCount(now) - if pc == 0 { - continue - } - if m.inProgressProbes[string(ab)] >= pc { + ab := string(a.Bytes()) + pc := m.statuses[ab].RequiredProbeCount(now) + if m.inProgressProbes[ab] >= pc { continue } - reqs := make([]autonatv2.Request, 0, maxAddrsPerRequest) + reqs := make(probe, 0, maxAddrsPerRequest) reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true}) // We have the first(primary) address. Append other addresses, ignoring inprogress probes // on secondary addresses. The expectation is that the primary address will // be dialed. for j := 1; j < len(m.addrs); j++ { k := (i + j) % len(m.addrs) - ab := m.addrs[k].Bytes() - pc := m.statuses[string(ab)].RequiredProbeCount(now) + ab := string(m.addrs[k].Bytes()) + pc := m.statuses[ab].RequiredProbeCount(now) if pc == 0 { continue } @@ -457,7 +455,7 @@ func (m *probeManager) GetProbe() []autonatv2.Request { // MarkProbeInProgress should be called when a probe is started. // All in progress probes *MUST* be completed with `CompleteProbe` -func (m *probeManager) MarkProbeInProgress(reqs []autonatv2.Request) { +func (m *probeManager) MarkProbeInProgress(reqs probe) { if len(reqs) == 0 { return } @@ -475,7 +473,7 @@ func (m *probeManager) InProgressProbes() int { } // CompleteProbe should be called when a probe completes. 
-func (m *probeManager) CompleteProbe(reqs []autonatv2.Request, res autonatv2.Result, err error) { +func (m *probeManager) CompleteProbe(reqs probe, res autonatv2.Result, err error) { now := m.now() if len(reqs) == 0 { diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go index 84223291de..a58b60db48 100644 --- a/p2p/host/basic/addrs_reachability_tracker_test.go +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -219,12 +219,12 @@ func TestAddrsReachabilityTracker(t *testing.T) { tr := &addrsReachabilityTracker{ ctx: ctx, cancel: cancel, - cli: cli, + client: cli, newAddrs: make(chan []ma.Multiaddr, 1), reachabilityUpdateCh: make(chan struct{}, 1), maxConcurrency: 3, newAddrsProbeDelay: 0 * time.Second, - addrTracker: newProbeManager(cl.Now), + probeManager: newProbeManager(cl.Now), clock: cl, } err := tr.Start() @@ -460,7 +460,6 @@ func TestAddrsReachabilityTracker(t *testing.T) { cl.Add(1) time.Sleep(100 * time.Millisecond) require.True(t, drainNotify()) // check that we did receive probes - cl.Add(highConfidenceAddrProbeInterval / 2) select { case <-notify: @@ -477,11 +476,22 @@ func TestAddrsReachabilityTracker(t *testing.T) { }) } -func TestRunProbes(t *testing.T) { +func TestRefreshReachability(t *testing.T) { pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") pub2 := ma.StringCast("/ip4/1.1.1.1/tcp/2") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + + newTracker := func(client autonatv2Client, pm *probeManager) *addrsReachabilityTracker { + return &addrsReachabilityTracker{ + probeManager: pm, + client: client, + clock: clock.New(), + maxConcurrency: 3, + ctx: ctx, + cancel: cancel, + } + } t.Run("backoff on ErrNoValidPeers", func(t *testing.T) { mockClient := mockAutoNATClient{ F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { @@ -491,8 +501,9 @@ func TestRunProbes(t *testing.T) { addrTracker := newProbeManager(time.Now) 
addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) - result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) - require.True(t, result) + r := newTracker(mockClient, addrTracker) + res := r.refreshReachability() + require.True(t, <-res.BackoffCh) require.Equal(t, addrTracker.InProgressProbes(), 0) }) @@ -504,12 +515,12 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now) - addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) - - result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) - require.True(t, result) - require.Equal(t, addrTracker.InProgressProbes(), 0) + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub1}) + r := newTracker(mockClient, pm) + result := r.refreshReachability() + require.True(t, <-result.BackoffCh) + require.Equal(t, pm.InProgressProbes(), 0) }) t.Run("quits on cancellation", func(t *testing.T) { @@ -522,15 +533,22 @@ func TestRunProbes(t *testing.T) { }, } - addrTracker := newProbeManager(time.Now) - addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub1}) + r := &addrsReachabilityTracker{ + ctx: ctx, + cancel: cancel, + client: mockClient, + probeManager: pm, + clock: clock.New(), + } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) - assert.False(t, result) - assert.Equal(t, addrTracker.InProgressProbes(), 0) + result := r.refreshReachability() + assert.False(t, <-result.BackoffCh) + assert.Equal(t, pm.InProgressProbes(), 0) }() cancel() @@ -555,9 +573,6 @@ func TestRunProbes(t *testing.T) { t.Run("handles refusals", func(t *testing.T) { pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") - addrTracker := newProbeManager(time.Now) - addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) - mockClient := mockAutoNATClient{ F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { for i, req := range reqs { 
@@ -569,19 +584,20 @@ func TestRunProbes(t *testing.T) { }, } - result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) - require.False(t, result) + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) + r := newTracker(mockClient, pm) + + result := r.refreshReachability() + require.False(t, <-result.BackoffCh) - reachable, unreachable := addrTracker.AppendConfirmedAddrs(nil, nil) + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) require.Equal(t, reachable, []ma.Multiaddr{pub1}) require.Empty(t, unreachable) - require.Equal(t, addrTracker.InProgressProbes(), 0) + require.Equal(t, pm.InProgressProbes(), 0) }) t.Run("handles completions", func(t *testing.T) { - addrTracker := newProbeManager(time.Now) - addrTracker.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) - mockClient := mockAutoNATClient{ F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { for i, req := range reqs { @@ -595,14 +611,16 @@ func TestRunProbes(t *testing.T) { return autonatv2.Result{AllAddrsRefused: true}, nil }, } + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) + r := newTracker(mockClient, pm) + result := r.refreshReachability() + require.False(t, <-result.BackoffCh) - result := runProbes(ctx, defaultMaxConcurrency, addrTracker, mockClient) - require.False(t, result) - - reachable, unreachable := addrTracker.AppendConfirmedAddrs(nil, nil) + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) require.Equal(t, reachable, []ma.Multiaddr{pub1}) require.Equal(t, unreachable, []ma.Multiaddr{pub2}) - require.Equal(t, addrTracker.InProgressProbes(), 0) + require.Equal(t, pm.InProgressProbes(), 0) }) } diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index d03bc66b8f..866c02381f 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -284,9 +284,11 @@ func (p *peersMap) Delete(id peer.ID) { if !ok { return } + n := 
len(p.peers) + lastPeer := p.peers[n-1] + p.peers[idx] = lastPeer + p.peerIdx[lastPeer] = idx + p.peers[n-1] = "" + p.peers = p.peers[:n-1] delete(p.peerIdx, id) - p.peers[idx] = p.peers[len(p.peers)-1] - p.peerIdx[p.peers[idx]] = idx - p.peers[len(p.peers)-1] = "" - p.peers = p.peers[:len(p.peers)-1] } diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 4cbfde4c61..df79d52c4e 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -660,6 +660,32 @@ func TestAreAddrsConsistency(t *testing.T) { } } +func TestPeerMap(t *testing.T) { + pm := newPeersMap() + // Add 1, 2, 3 + pm.Put(peer.ID("1")) + pm.Put(peer.ID("2")) + pm.Put(peer.ID("3")) + + // Remove 3, 2 + pm.Delete(peer.ID("3")) + pm.Delete(peer.ID("2")) + + // Add 4 + pm.Put(peer.ID("4")) + + // Remove 3, 2 again. Should be no op + pm.Delete(peer.ID("3")) + pm.Delete(peer.ID("2")) + + contains := []peer.ID{"1", "4"} + elems := make([]peer.ID, 0) + for p := range pm.Shuffled() { + elems = append(elems, p) + } + require.ElementsMatch(t, contains, elems) +} + func FuzzClient(f *testing.F) { a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2)) c := newAutoNAT(f, nil)