diff --git a/.golangci.yml b/.golangci.yml index 645266d523..eac7bfe3fd 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -8,6 +8,8 @@ linters: - revive - unused - prealloc + disable: + - errcheck settings: revive: diff --git a/config/config.go b/config/config.go index 607dceba49..836915b7aa 100644 --- a/config/config.go +++ b/config/config.go @@ -33,6 +33,7 @@ import ( routed "github.com/libp2p/go-libp2p/p2p/host/routed" "github.com/libp2p/go-libp2p/p2p/net/swarm" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" @@ -413,15 +414,7 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { return fxopts, nil } -func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) { - var autonatv2Dialer host.Host - if cfg.EnableAutoNATv2 { - ah, err := cfg.makeAutoNATV2Host() - if err != nil { - return nil, err - } - autonatv2Dialer = ah - } +func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus, an *autonatv2.AutoNAT) (*bhost.BasicHost, error) { h, err := bhost.NewHost(swrm, &bhost.HostOpts{ EventBus: eventBus, ConnManager: cfg.ConnManager, @@ -437,8 +430,7 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B EnableMetrics: !cfg.DisableMetrics, PrometheusRegisterer: cfg.PrometheusRegisterer, DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery, - EnableAutoNATv2: cfg.EnableAutoNATv2, - AutoNATv2Dialer: autonatv2Dialer, + AutoNATv2: an, }) if err != nil { return nil, err @@ -517,6 +509,24 @@ func (cfg *Config) NewNode() (host.Host, error) { }) return sw, nil }), + fx.Provide(func() (*autonatv2.AutoNAT, error) { + if !cfg.EnableAutoNATv2 { + return nil, nil + } + ah, err := cfg.makeAutoNATV2Host() + if err != nil { + return nil, err + } + var mt autonatv2.MetricsTracer + if !cfg.DisableMetrics { + mt = autonatv2.NewMetricsTracer(cfg.PrometheusRegisterer) + } + autoNATv2, err := autonatv2.New(ah, autonatv2.WithMetricsTracer(mt)) + if err != nil { + return nil, fmt.Errorf("failed to create autonatv2: %w", err) + } + return autoNATv2, nil + }), fx.Provide(cfg.newBasicHost), fx.Provide(func(bh *bhost.BasicHost) identify.IDService { return bh.IDService() diff --git a/core/event/reachability.go b/core/event/reachability.go index 6ab4523851..8aa2cd07cc 100644 --- a/core/event/reachability.go +++ b/core/event/reachability.go @@ -2,6 +2,7 @@ package event import ( "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" ) // EvtLocalReachabilityChanged is an event struct to be emitted when the local's @@ -11,3 +12,12 @@ import ( type EvtLocalReachabilityChanged struct { Reachability network.Reachability } + +// EvtHostReachableAddrsChanged is sent when host's reachable or unreachable addresses change +// Reachable and Unreachable both contain only Public IP or DNS addresses +// +// Experimental: This API is unstable. Any changes to this event will be done without a deprecation notice. 
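+//
+// Example consumer (a sketch; h is an illustrative host.Host variable and the
+// handling shown here is not part of this change):
+//
+//	sub, err := h.EventBus().Subscribe(new(event.EvtHostReachableAddrsChanged))
+//	if err != nil {
+//		return err
+//	}
+//	defer sub.Close()
+//	for e := range sub.Out() {
+//		evt := e.(event.EvtHostReachableAddrsChanged)
+//		fmt.Println("reachable:", evt.Reachable, "unreachable:", evt.Unreachable)
+//	}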
+type EvtHostReachableAddrsChanged struct { + Reachable []ma.Multiaddr + Unreachable []ma.Multiaddr +} diff --git a/p2p/host/basic/addrs_manager.go b/p2p/host/basic/addrs_manager.go index 6b984a9dd5..217570ae86 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -2,6 +2,7 @@ package basichost import ( "context" + "errors" "fmt" "net" "slices" @@ -13,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/libp2p/go-netroute" @@ -27,24 +29,36 @@ type observedAddrsManager interface { ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr } +type hostAddrs struct { + addrs []ma.Multiaddr + localAddrs []ma.Multiaddr + reachableAddrs []ma.Multiaddr + unreachableAddrs []ma.Multiaddr + relayAddrs []ma.Multiaddr +} + type addrsManager struct { - eventbus event.Bus - natManager NATManager - addrsFactory AddrsFactory - listenAddrs func() []ma.Multiaddr - transportForListening func(ma.Multiaddr) transport.Transport - observedAddrsManager observedAddrsManager - interfaceAddrs *interfaceAddrsCache + bus event.Bus + natManager NATManager + addrsFactory AddrsFactory + listenAddrs func() []ma.Multiaddr + transportForListening func(ma.Multiaddr) transport.Transport + observedAddrsManager observedAddrsManager + interfaceAddrs *interfaceAddrsCache + addrsReachabilityTracker *addrsReachabilityTracker + + // addrsUpdatedChan is notified when addrs change. This is provided by the caller. + addrsUpdatedChan chan struct{} // triggerAddrsUpdateChan is used to trigger an addresses update. triggerAddrsUpdateChan chan struct{} - // addrsUpdatedChan is notified when addresses change. - addrsUpdatedChan chan struct{} + // triggerReachabilityUpdate is notified when reachable addrs are updated. 
+ triggerReachabilityUpdate chan struct{} + hostReachability atomic.Pointer[network.Reachability] - addrsMx sync.RWMutex // protects fields below - localAddrs []ma.Multiaddr - relayAddrs []ma.Multiaddr + addrsMx sync.RWMutex + currentAddrs hostAddrs wg sync.WaitGroup ctx context.Context @@ -52,35 +66,49 @@ type addrsManager struct { } func newAddrsManager( - eventbus event.Bus, + bus event.Bus, natmgr NATManager, addrsFactory AddrsFactory, listenAddrs func() []ma.Multiaddr, transportForListening func(ma.Multiaddr) transport.Transport, observedAddrsManager observedAddrsManager, addrsUpdatedChan chan struct{}, + client autonatv2Client, ) (*addrsManager, error) { ctx, cancel := context.WithCancel(context.Background()) as := &addrsManager{ - eventbus: eventbus, - listenAddrs: listenAddrs, - transportForListening: transportForListening, - observedAddrsManager: observedAddrsManager, - natManager: natmgr, - addrsFactory: addrsFactory, - triggerAddrsUpdateChan: make(chan struct{}, 1), - addrsUpdatedChan: addrsUpdatedChan, - interfaceAddrs: &interfaceAddrsCache{}, - ctx: ctx, - ctxCancel: cancel, + bus: bus, + listenAddrs: listenAddrs, + transportForListening: transportForListening, + observedAddrsManager: observedAddrsManager, + natManager: natmgr, + addrsFactory: addrsFactory, + triggerAddrsUpdateChan: make(chan struct{}, 1), + triggerReachabilityUpdate: make(chan struct{}, 1), + addrsUpdatedChan: addrsUpdatedChan, + interfaceAddrs: &interfaceAddrsCache{}, + ctx: ctx, + ctxCancel: cancel, } unknownReachability := network.ReachabilityUnknown as.hostReachability.Store(&unknownReachability) + + if client != nil { + as.addrsReachabilityTracker = newAddrsReachabilityTracker(client, as.triggerReachabilityUpdate, nil) + } return as, nil } func (a *addrsManager) Start() error { - return a.background() + // TODO: add Start method to NATMgr + if a.addrsReachabilityTracker != nil { + err := a.addrsReachabilityTracker.Start() + if err != nil { + return fmt.Errorf("error starting addrs reachability tracker: %s", err) + } + } + + return a.startBackgroundWorker() } func (a *addrsManager) Close() { @@ -91,10 +119,18 @@ func (a *addrsManager) Close() { log.Warnf("error closing natmgr: %s", err) } } + if a.addrsReachabilityTracker != nil { + err := a.addrsReachabilityTracker.Close() + if err != nil { + log.Warnf("error closing addrs reachability tracker: %s", err) + } + } a.wg.Wait() } func (a *addrsManager) NetNotifee() network.Notifiee { + // Updating addrs in sync provides the nice property that + // host.Addrs() just after host.Network().Listen(x) will return x return &network.NotifyBundle{ ListenF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() }, ListenCloseF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() }, @@ -102,37 +138,53 @@ func (a *addrsManager) NetNotifee() network.Notifiee { } func (a *addrsManager) triggerAddrsUpdate() { - // This is ugly, we update here *and* in the background loop, but this ensures the nice property - // that host.Addrs after host.Network().Listen(...) will return the recently added listen address. 
- a.updateLocalAddrs() + a.updateAddrs(false, nil) select { case a.triggerAddrsUpdateChan <- struct{}{}: default: } } -func (a *addrsManager) background() error { - autoRelayAddrsSub, err := a.eventbus.Subscribe(new(event.EvtAutoRelayAddrsUpdated)) +func (a *addrsManager) startBackgroundWorker() error { + autoRelayAddrsSub, err := a.bus.Subscribe(new(event.EvtAutoRelayAddrsUpdated), eventbus.Name("addrs-manager")) if err != nil { return fmt.Errorf("error subscribing to auto relay addrs: %s", err) } - autonatReachabilitySub, err := a.eventbus.Subscribe(new(event.EvtLocalReachabilityChanged)) + autonatReachabilitySub, err := a.bus.Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("addrs-manager")) + if err != nil { + err1 := autoRelayAddrsSub.Close() + if err1 != nil { + err1 = fmt.Errorf("error closing autorelaysub: %w", err1) + } + err = fmt.Errorf("error subscribing to autonat reachability: %s", err) + return errors.Join(err, err1) + } + + emitter, err := a.bus.Emitter(new(event.EvtHostReachableAddrsChanged), eventbus.Stateful) if err != nil { - return fmt.Errorf("error subscribing to autonat reachability: %s", err) + err1 := autoRelayAddrsSub.Close() + if err1 != nil { + err1 = fmt.Errorf("error closing autorelaysub: %w", err1) + } + err2 := autonatReachabilitySub.Close() + if err2 != nil { + err2 = fmt.Errorf("error closing autonat reachability sub: %w", err2) + } + err = fmt.Errorf("error creating reachable addrs emitter: %s", err) + return errors.Join(err, err1, err2) } - // ensure that we have the correct address after returning from Start() - // update local addrs - a.updateLocalAddrs() + var relayAddrs []ma.Multiaddr // update relay addrs in case we're private select { case e := <-autoRelayAddrsSub.Out(): if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { - a.updateRelayAddrs(evt.RelayAddrs) + relayAddrs = slices.Clone(evt.RelayAddrs) } default: } + select { case e := <-autonatReachabilitySub.Out(): if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { @@ -140,70 +192,149 @@ func (a *addrsManager) background() error { } default: } + // update addresses before starting the worker loop. This ensures that any address updates + // before calling addrsManager.Start are correctly reported after Start returns.
+ a.updateAddrs(true, relayAddrs) a.wg.Add(1) - go func() { - defer a.wg.Done() - defer func() { - err := autoRelayAddrsSub.Close() - if err != nil { - log.Warnf("error closing auto relay addrs sub: %s", err) - } - }() - defer func() { - err := autonatReachabilitySub.Close() - if err != nil { - log.Warnf("error closing autonat reachability sub: %s", err) - } - }() - - ticker := time.NewTicker(addrChangeTickrInterval) - defer ticker.Stop() - var prev []ma.Multiaddr - for { - a.updateLocalAddrs() - curr := a.Addrs() - if a.areAddrsDifferent(prev, curr) { - log.Debugf("host addresses updated: %s", curr) - select { - case a.addrsUpdatedChan <- struct{}{}: - default: - } + go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, relayAddrs) + return nil +} + +func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription, + emitter event.Emitter, relayAddrs []ma.Multiaddr, +) { + defer a.wg.Done() + defer func() { + err := autoRelayAddrsSub.Close() + if err != nil { + log.Warnf("error closing auto relay addrs sub: %s", err) + } + err = autonatReachabilitySub.Close() + if err != nil { + log.Warnf("error closing autonat reachability sub: %s", err) + } + }() + + ticker := time.NewTicker(addrChangeTickrInterval) + defer ticker.Stop() + var previousAddrs hostAddrs + for { + currAddrs := a.updateAddrs(true, relayAddrs) + a.notifyAddrsChanged(emitter, previousAddrs, currAddrs) + previousAddrs = currAddrs + select { + case <-ticker.C: + case <-a.triggerAddrsUpdateChan: + case <-a.triggerReachabilityUpdate: + case e := <-autoRelayAddrsSub.Out(): + if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { + relayAddrs = slices.Clone(evt.RelayAddrs) } - prev = curr - select { - case <-ticker.C: - case <-a.triggerAddrsUpdateChan: - case e := <-autoRelayAddrsSub.Out(): - if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok { - a.updateRelayAddrs(evt.RelayAddrs) - } - case e := <-autonatReachabilitySub.Out(): - if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { - a.hostReachability.Store(&evt.Reachability) - } - case <-a.ctx.Done(): - return + case e := <-autonatReachabilitySub.Out(): + if evt, ok := e.(event.EvtLocalReachabilityChanged); ok { + a.hostReachability.Store(&evt.Reachability) } + case <-a.ctx.Done(): + return } - }() - return nil + } +} + +// updateAddrs updates the addresses of the host and returns the new updated +// addrs +func (a *addrsManager) updateAddrs(updateRelayAddrs bool, relayAddrs []ma.Multiaddr) hostAddrs { + // Must lock while doing both recompute and update as this method is called from + // multiple goroutines. 
+ a.addrsMx.Lock() + defer a.addrsMx.Unlock() + + localAddrs := a.getLocalAddrs() + var currReachableAddrs, currUnreachableAddrs []ma.Multiaddr + if a.addrsReachabilityTracker != nil { + currReachableAddrs, currUnreachableAddrs = a.getConfirmedAddrs(localAddrs) + } + if !updateRelayAddrs { + relayAddrs = a.currentAddrs.relayAddrs + } else { + // Copy the callers slice + relayAddrs = slices.Clone(relayAddrs) + } + currAddrs := a.getAddrs(slices.Clone(localAddrs), relayAddrs) + + a.currentAddrs = hostAddrs{ + addrs: append(a.currentAddrs.addrs[:0], currAddrs...), + localAddrs: append(a.currentAddrs.localAddrs[:0], localAddrs...), + reachableAddrs: append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...), + unreachableAddrs: append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...), + relayAddrs: append(a.currentAddrs.relayAddrs[:0], relayAddrs...), + } + + return hostAddrs{ + localAddrs: localAddrs, + addrs: currAddrs, + reachableAddrs: currReachableAddrs, + unreachableAddrs: currUnreachableAddrs, + relayAddrs: relayAddrs, + } +} + +func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) { + if areAddrsDifferent(previous.localAddrs, current.localAddrs) { + log.Debugf("host local addresses updated: %s", current.localAddrs) + if a.addrsReachabilityTracker != nil { + a.addrsReachabilityTracker.UpdateAddrs(current.localAddrs) + } + } + if areAddrsDifferent(previous.addrs, current.addrs) { + log.Debugf("host addresses updated: %s", current.localAddrs) + select { + case a.addrsUpdatedChan <- struct{}{}: + default: + } + } + + // We *must* send both reachability changed and addrs changed events from the + // same goroutine to ensure correct ordering + // Consider the events: + // - addr x discovered + // - addr x is reachable + // - addr x removed + // We must send these events in the same order. It'll be confusing for consumers + // if the reachable event is received after the addr removed event. + if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) || + areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) { + log.Debugf("host reachable addrs updated: %s", current.localAddrs) + if err := emitter.Emit(event.EvtHostReachableAddrsChanged{ + Reachable: slices.Clone(current.reachableAddrs), + Unreachable: slices.Clone(current.unreachableAddrs), + }); err != nil { + log.Errorf("error sending host reachable addrs changed event: %s", err) + } + } } // Addrs returns the node's dialable addresses both public and private. // If autorelay is enabled and node reachability is private, it returns // the node's relay addresses and private network addresses. func (a *addrsManager) Addrs() []ma.Multiaddr { - addrs := a.DirectAddrs() + a.addrsMx.RLock() + directAddrs := slices.Clone(a.currentAddrs.localAddrs) + relayAddrs := slices.Clone(a.currentAddrs.relayAddrs) + a.addrsMx.RUnlock() + return a.getAddrs(directAddrs, relayAddrs) +} + +// getAddrs returns the node's dialable addresses. Mutates localAddrs +func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr { + addrs := localAddrs rch := a.hostReachability.Load() if rch != nil && *rch == network.ReachabilityPrivate { - a.addrsMx.RLock() // Delete public addresses if the node's reachability is private, and we have relay addresses - if len(a.relayAddrs) > 0 { + if len(relayAddrs) > 0 { addrs = slices.DeleteFunc(addrs, manet.IsPublicAddr) - addrs = append(addrs, a.relayAddrs...) + addrs = append(addrs, relayAddrs...) 
} - a.addrsMx.RUnlock() } // Make a copy. Consumers can modify the slice elements addrs = slices.Clone(a.addrsFactory(addrs)) @@ -213,7 +344,8 @@ func (a *addrsManager) Addrs() []ma.Multiaddr { return addrs } -// HolePunchAddrs returns the node's public direct listen addresses for hole punching. +// HolePunchAddrs returns all the host's direct public addresses, reachable or unreachable, +// suitable for hole punching. func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr { addrs := a.DirectAddrs() addrs = slices.Clone(a.addrsFactory(addrs)) @@ -230,26 +362,23 @@ func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr { func (a *addrsManager) DirectAddrs() []ma.Multiaddr { a.addrsMx.RLock() defer a.addrsMx.RUnlock() - return slices.Clone(a.localAddrs) + return slices.Clone(a.currentAddrs.localAddrs) } -func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) { - a.addrsMx.Lock() - defer a.addrsMx.Unlock() - a.relayAddrs = append(a.relayAddrs[:0], addrs...) +// ReachableAddrs returns all addresses of the host that are reachable from the internet +func (a *addrsManager) ReachableAddrs() []ma.Multiaddr { + a.addrsMx.RLock() + defer a.addrsMx.RUnlock() + return slices.Clone(a.currentAddrs.reachableAddrs) } -var p2pCircuitAddr = ma.StringCast("/p2p-circuit") - -func (a *addrsManager) updateLocalAddrs() { - localAddrs := a.getLocalAddrs() - slices.SortFunc(localAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) - - a.addrsMx.Lock() - a.localAddrs = localAddrs - a.addrsMx.Unlock() +func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + reachableAddrs, unreachableAddrs = a.addrsReachabilityTracker.ConfirmedAddrs() + return removeNotInSource(reachableAddrs, localAddrs), removeNotInSource(unreachableAddrs, localAddrs) } +var p2pCircuitAddr = ma.StringCast("/p2p-circuit") + func (a *addrsManager) getLocalAddrs() []ma.Multiaddr { listenAddrs := a.listenAddrs() if len(listenAddrs) == 0 { @@ -260,8 +389,6 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr { finalAddrs = a.appendPrimaryInterfaceAddrs(finalAddrs, listenAddrs) finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs, a.interfaceAddrs.All()) - finalAddrs = ma.Unique(finalAddrs) - // Remove "/p2p-circuit" addresses from the list. // The p2p-circuit listener reports its address as just /p2p-circuit. This is // useless for dialing. Users need to manage their circuit addresses themselves, @@ -278,6 +405,8 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr { // Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered // using identify. finalAddrs = a.addCertHashes(finalAddrs) + finalAddrs = ma.Unique(finalAddrs) + slices.SortFunc(finalAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) return finalAddrs } @@ -408,7 +537,7 @@ func (a *addrsManager) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs } -func (a *addrsManager) areAddrsDifferent(prev, current []ma.Multiaddr) bool { +func areAddrsDifferent(prev, current []ma.Multiaddr) bool { // TODO: make the sorted nature of ma.Unique a guarantee in multiaddrs prev = ma.Unique(prev) current = ma.Unique(current) @@ -547,3 +676,31 @@ func (i *interfaceAddrsCache) updateUnlocked() { } } } + +// removeNotInSource removes items from addrs that are not present in source. +// Modifies the addrs slice in place +// addrs and source must be sorted using multiaddr.Compare. 
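+//
+// For example (letters stand for sorted multiaddrs; illustrative only), with
+// addrs = [A B C D] and source = [B D E], the two sorted slices are walked
+// together merge-style: A < B so A is dropped, B matches and is kept, C < D so
+// C is dropped, D matches and is kept, giving [B D].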
+func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr { + j := 0 + // mark entries not in source as nil + for i, a := range addrs { + // move right as long as a > source[j] + for j < len(source) && a.Compare(source[j]) > 0 { + j++ + } + // a is not in source if we've reached the end, or a is lesser + if j == len(source) || a.Compare(source[j]) < 0 { + addrs[i] = nil + } + // a is in source, nothing to do + } + // j is the current element, i is the lowest index nil element + i := 0 + for j := range len(addrs) { + if addrs[j] != nil { + addrs[i], addrs[j] = addrs[j], addrs[i] + i++ + } + } + return addrs[:i] +} diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index 49e46f2530..56f9faaf42 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -1,13 +1,17 @@ package basichost import ( + "context" + "errors" "fmt" + "slices" "testing" "time" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/assert" @@ -30,7 +34,7 @@ func TestAppendNATAddrs(t *testing.T) { // nat mapping success, obsaddress ignored Listen: ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1"), Nat: ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1"), - ObsAddrFunc: func(m ma.Multiaddr) []ma.Multiaddr { + ObsAddrFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/udp/100/quic-v1")} }, Expected: []ma.Multiaddr{ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1")}, @@ -116,7 +120,7 @@ func TestAppendNATAddrs(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { as := &addrsManager{ natManager: &mockNatManager{ - GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr { + GetMappingFunc: func(_ ma.Multiaddr) ma.Multiaddr { return tc.Nat }, }, @@ -135,7 +139,7 @@ type mockNatManager struct { GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr } -func (m *mockNatManager) Close() error { +func (*mockNatManager) Close() error { return nil } @@ -146,7 +150,7 @@ func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { return m.GetMappingFunc(addr) } -func (m *mockNatManager) HasDiscoveredNAT() bool { +func (*mockNatManager) HasDiscoveredNAT() bool { return true } @@ -170,6 +174,8 @@ type addrsManagerArgs struct { AddrsFactory AddrsFactory ObservedAddrsManager observedAddrsManager ListenAddrs func() []ma.Multiaddr + AutoNATClient autonatv2Client + Bus event.Bus } type addrsManagerTestCase struct { @@ -179,13 +185,16 @@ type addrsManagerTestCase struct { } func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTestCase { - eb := eventbus.NewBus() + eb := args.Bus + if eb == nil { + eb = eventbus.NewBus() + } if args.AddrsFactory == nil { args.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs } } addrsUpdatedChan := make(chan struct{}, 1) am, err := newAddrsManager( - eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan, + eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan, args.AutoNATClient, ) require.NoError(t, err) @@ -196,6 +205,7 @@ func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTe rchEm, err := eb.Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) require.NoError(t, 
err) + t.Cleanup(am.Close) return addrsManagerTestCase{ addrsManager: am, PushRelay: func(relayAddrs []ma.Multiaddr) { @@ -326,7 +336,7 @@ func TestAddrsManager(t *testing.T) { } am := newAddrsManagerTestCase(t, addrsManagerArgs{ ObservedAddrsManager: &mockObservedAddrs{ - ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return quicAddrs }, }, @@ -342,7 +352,7 @@ func TestAddrsManager(t *testing.T) { t.Run("public addrs removed when private", func(t *testing.T) { am := newAddrsManagerTestCase(t, addrsManagerArgs{ ObservedAddrsManager: &mockObservedAddrs{ - ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return []ma.Multiaddr{publicQUIC} }, }, @@ -384,7 +394,7 @@ func TestAddrsManager(t *testing.T) { return nil }, ObservedAddrsManager: &mockObservedAddrs{ - ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr { return []ma.Multiaddr{publicQUIC} }, }, @@ -404,7 +414,7 @@ func TestAddrsManager(t *testing.T) { t.Run("updates addresses on signaling", func(t *testing.T) { updateChan := make(chan struct{}) am := newAddrsManagerTestCase(t, addrsManagerArgs{ - AddrsFactory: func(addrs []ma.Multiaddr) []ma.Multiaddr { + AddrsFactory: func(_ []ma.Multiaddr) []ma.Multiaddr { select { case <-updateChan: return []ma.Multiaddr{publicQUIC} @@ -425,17 +435,95 @@ func TestAddrsManager(t *testing.T) { }) } +func TestAddrsManagerReachabilityEvent(t *testing.T) { + publicQUIC, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1") + publicQUIC2, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1235/quic-v1") + publicTCP, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") + + bus := eventbus.NewBus() + + sub, err := bus.Subscribe(new(event.EvtHostReachableAddrsChanged)) + require.NoError(t, err) + defer sub.Close() + + am := newAddrsManagerTestCase(t, addrsManagerArgs{ + Bus: bus, + // currently they aren't being passed to the reachability tracker + ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicQUIC2, publicTCP} }, + AutoNATClient: mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + if reqs[0].Addr.Equal(publicQUIC) { + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil + } else if reqs[0].Addr.Equal(publicTCP) || reqs[0].Addr.Equal(publicQUIC2) { + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPrivate}, nil + } + return autonatv2.Result{}, errors.New("invalid") + }, + }, + }) + + reachableAddrs := []ma.Multiaddr{publicQUIC} + unreachableAddrs := []ma.Multiaddr{publicTCP, publicQUIC2} + select { + case e := <-sub.Out(): + evt := e.(event.EvtHostReachableAddrsChanged) + require.ElementsMatch(t, reachableAddrs, evt.Reachable) + require.ElementsMatch(t, unreachableAddrs, evt.Unreachable) + require.ElementsMatch(t, reachableAddrs, am.ReachableAddrs()) + case <-time.After(5 * time.Second): + t.Fatal("expected event for reachability change") + } +} + +func TestRemoveIfNotInSource(t *testing.T) { + var addrs []ma.Multiaddr + for i := 0; i < 10; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", i))) + } + slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + cases := []struct { + addrs []ma.Multiaddr + source []ma.Multiaddr + expected []ma.Multiaddr + }{ + {}, + {addrs: slices.Clone(addrs[:5]), 
source: nil, expected: nil}, + {addrs: nil, source: addrs, expected: nil}, + {addrs: []ma.Multiaddr{addrs[0]}, source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}}, + {addrs: slices.Clone(addrs), source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}}, + {addrs: slices.Clone(addrs), source: slices.Clone(addrs[5:]), expected: slices.Clone(addrs[5:])}, + {addrs: slices.Clone(addrs[:5]), source: []ma.Multiaddr{addrs[0], addrs[2], addrs[8]}, expected: []ma.Multiaddr{addrs[0], addrs[2]}}, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + addrs := removeNotInSource(tc.addrs, tc.source) + require.ElementsMatch(t, tc.expected, addrs, "%s\n%s", tc.expected, tc.addrs) + }) + } +} + func BenchmarkAreAddrsDifferent(b *testing.B) { var addrs [10]ma.Multiaddr for i := 0; i < len(addrs); i++ { addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i)) } - am := &addrsManager{} b.Run("areAddrsDifferent", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - am.areAddrsDifferent(addrs[:], addrs[:]) + areAddrsDifferent(addrs[:], addrs[:]) } }) } + +func BenchmarkRemoveIfNotInSource(b *testing.B) { + var addrs [10]ma.Multiaddr + for i := 0; i < len(addrs); i++ { + addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i)) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + removeNotInSource(slices.Clone(addrs[:5]), addrs[:]) + } +} diff --git a/p2p/host/basic/addrs_reachability_tracker.go b/p2p/host/basic/addrs_reachability_tracker.go new file mode 100644 index 0000000000..2d09a34ebc --- /dev/null +++ b/p2p/host/basic/addrs_reachability_tracker.go @@ -0,0 +1,666 @@ +package basichost + +import ( + "context" + "errors" + "fmt" + "math" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +type autonatv2Client interface { + GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) +} + +const ( + + // maxAddrsPerRequest is the maximum number of addresses to probe in a single request + maxAddrsPerRequest = 10 + // maxTrackedAddrs is the maximum number of addresses to track + // 10 addrs per transport for 5 transports + maxTrackedAddrs = 50 + // defaultMaxConcurrency is the default number of concurrent workers for reachability checks + defaultMaxConcurrency = 5 + // newAddrsProbeDelay is the delay before probing new addr's reachability. + newAddrsProbeDelay = 1 * time.Second +) + +// addrsReachabilityTracker tracks reachability for addresses. +// Use UpdateAddrs to provide addresses for tracking reachability. +// reachabilityUpdateCh is notified when reachability for any of the tracked address changes. +type addrsReachabilityTracker struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + client autonatv2Client + // reachabilityUpdateCh is used to notify when reachability may have changed + reachabilityUpdateCh chan struct{} + maxConcurrency int + newAddrsProbeDelay time.Duration + probeManager *probeManager + newAddrs chan []ma.Multiaddr + clock clock.Clock + + mx sync.Mutex + reachableAddrs []ma.Multiaddr + unreachableAddrs []ma.Multiaddr +} + +// newAddrsReachabilityTracker returns a new addrsReachabilityTracker. +// reachabilityUpdateCh is notified when reachability for any of the tracked address changes. 
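+//
+// Intended wiring (a sketch of how addrsManager uses it; client, updateCh and
+// addrs are illustrative):
+//
+//	rt := newAddrsReachabilityTracker(client, updateCh, nil) // nil clock uses the real clock
+//	_ = rt.Start()
+//	rt.UpdateAddrs(addrs) // called whenever the host's local addrs change
+//	<-updateCh            // notified when confirmed reachability changes
+//	reachable, unreachable := rt.ConfirmedAddrs()
+//	_ = rt.Close()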
+func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh chan struct{}, cl clock.Clock) *addrsReachabilityTracker { + ctx, cancel := context.WithCancel(context.Background()) + if cl == nil { + cl = clock.New() + } + return &addrsReachabilityTracker{ + ctx: ctx, + cancel: cancel, + client: client, + reachabilityUpdateCh: reachabilityUpdateCh, + probeManager: newProbeManager(cl.Now), + newAddrsProbeDelay: newAddrsProbeDelay, + maxConcurrency: defaultMaxConcurrency, + newAddrs: make(chan []ma.Multiaddr, 1), + clock: cl, + } +} + +func (r *addrsReachabilityTracker) UpdateAddrs(addrs []ma.Multiaddr) { + select { + case r.newAddrs <- slices.Clone(addrs): + case <-r.ctx.Done(): + } +} + +func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + r.mx.Lock() + defer r.mx.Unlock() + return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs) +} + +func (r *addrsReachabilityTracker) Start() error { + r.wg.Add(1) + go r.background() + return nil +} + +func (r *addrsReachabilityTracker) Close() error { + r.cancel() + r.wg.Wait() + return nil +} + +const ( + // defaultReachabilityRefreshInterval is the default interval to refresh reachability. + // In steady state, we check for any required probes every refresh interval. + // This doesn't mean we'll probe for any particular address, only that we'll check + // if any address needs to be probed. + defaultReachabilityRefreshInterval = 5 * time.Minute + // maxBackoffInterval is the maximum back off in case we're unable to probe for reachability. + // We may be unable to confirm addresses in case there are no valid peers with autonatv2 + // or the autonatv2 subsystem is consistently erroring. + maxBackoffInterval = 5 * time.Minute + // backoffStartInterval is the initial back off in case we're unable to probe for reachability. + backoffStartInterval = 5 * time.Second +) + +func (r *addrsReachabilityTracker) background() { + defer r.wg.Done() + + // probeTicker is used to trigger probes at regular intervals + probeTicker := r.clock.Ticker(defaultReachabilityRefreshInterval) + defer probeTicker.Stop() + + // probeTimer is used to trigger probes at specific times + probeTimer := r.clock.Timer(time.Duration(math.MaxInt64)) + defer probeTimer.Stop() + nextProbeTime := time.Time{} + + var task reachabilityTask + var backoffInterval time.Duration + var currReachable, currUnreachable, prevReachable, prevUnreachable []ma.Multiaddr + for { + select { + case <-probeTicker.C: + // don't start a probe if we have a scheduled probe + if task.BackoffCh == nil && nextProbeTime.IsZero() { + task = r.refreshReachability() + } + case <-probeTimer.C: + if task.BackoffCh == nil { + task = r.refreshReachability() + } + nextProbeTime = time.Time{} + case backoff := <-task.BackoffCh: + task = reachabilityTask{} + // On completion, start the next probe immediately, or wait for backoff. + // In case there are no further probes, the reachability tracker will return an empty task, + // which hangs forever. Eventually, we'll refresh again when the ticker fires. + if backoff { + backoffInterval = newBackoffInterval(backoffInterval) + } else { + backoffInterval = -1 * time.Second // negative to trigger next probe immediately + } + nextProbeTime = r.clock.Now().Add(backoffInterval) + case addrs := <-r.newAddrs: + if task.BackoffCh != nil { // cancel running task. 
+ task.Cancel() + <-task.BackoffCh // ignore backoff from cancelled task + task = reachabilityTask{} + } + r.updateTrackedAddrs(addrs) + newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay) + if nextProbeTime.Before(newAddrsNextTime) { + nextProbeTime = newAddrsNextTime + } + case <-r.ctx.Done(): + if task.BackoffCh != nil { + task.Cancel() + <-task.BackoffCh + task = reachabilityTask{} + } + return + } + + currReachable, currUnreachable = r.appendConfirmedAddrs(currReachable[:0], currUnreachable[:0]) + if areAddrsDifferent(prevReachable, currReachable) || areAddrsDifferent(prevUnreachable, currUnreachable) { + r.notify() + } + prevReachable = append(prevReachable[:0], currReachable...) + prevUnreachable = append(prevUnreachable[:0], currUnreachable...) + if !nextProbeTime.IsZero() { + probeTimer.Reset(nextProbeTime.Sub(r.clock.Now())) + } + } +} + +func newBackoffInterval(current time.Duration) time.Duration { + if current <= 0 { + return backoffStartInterval + } + current *= 2 + if current > maxBackoffInterval { + return maxBackoffInterval + } + return current +} + +func (r *addrsReachabilityTracker) appendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + reachable, unreachable = r.probeManager.AppendConfirmedAddrs(reachable, unreachable) + r.mx.Lock() + r.reachableAddrs = append(r.reachableAddrs[:0], reachable...) + r.unreachableAddrs = append(r.unreachableAddrs[:0], unreachable...) + r.mx.Unlock() + return reachable, unreachable +} + +func (r *addrsReachabilityTracker) notify() { + select { + case r.reachabilityUpdateCh <- struct{}{}: + default: + } +} + +func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) { + addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { + return !manet.IsPublicAddr(a) + }) + if len(addrs) > maxTrackedAddrs { + log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs) + addrs = addrs[:maxTrackedAddrs] + } + r.probeManager.UpdateAddrs(addrs) +} + +type probe = []autonatv2.Request + +const probeTimeout = 30 * time.Second + +// reachabilityTask is a task to refresh reachability. +// Waiting on the zero value blocks forever. +type reachabilityTask struct { + Cancel context.CancelFunc + // BackoffCh returns whether the caller should backoff before + // refreshing reachability + BackoffCh chan bool +} + +func (r *addrsReachabilityTracker) refreshReachability() reachabilityTask { + if len(r.probeManager.GetProbe()) == 0 { + return reachabilityTask{} + } + resCh := make(chan bool, 1) + ctx, cancel := context.WithTimeout(r.ctx, 5*time.Minute) + r.wg.Add(1) + // We run probes provided by addrsTracker. 
It stops probing when any + // of the following happens: + // - there are no more probes to run + // - context is completed + // - there are too many consecutive failures from the client + // - the client has no valid peers to probe + go func() { + defer r.wg.Done() + defer cancel() + client := &errCountingClient{autonatv2Client: r.client, MaxConsecutiveErrors: maxConsecutiveErrors} + var backoff atomic.Bool + var wg sync.WaitGroup + wg.Add(r.maxConcurrency) + for range r.maxConcurrency { + go func() { + defer wg.Done() + for { + if ctx.Err() != nil { + return + } + reqs := r.probeManager.GetProbe() + if len(reqs) == 0 { + return + } + r.probeManager.MarkProbeInProgress(reqs) + rctx, cancel := context.WithTimeout(ctx, probeTimeout) + res, err := client.GetReachability(rctx, reqs) + cancel() + r.probeManager.CompleteProbe(reqs, res, err) + if isErrorPersistent(err) { + backoff.Store(true) + return + } + } + }() + } + wg.Wait() + resCh <- backoff.Load() + }() + return reachabilityTask{Cancel: cancel, BackoffCh: resCh} +} + +var errTooManyConsecutiveFailures = errors.New("too many consecutive failures") + +// errCountingClient counts errors from autonatv2Client and wraps the errors in response with a +// errTooManyConsecutiveFailures in case of persistent failures from autonatv2 module. +type errCountingClient struct { + autonatv2Client + MaxConsecutiveErrors int + mx sync.Mutex + consecutiveErrors int +} + +func (c *errCountingClient) GetReachability(ctx context.Context, reqs probe) (autonatv2.Result, error) { + res, err := c.autonatv2Client.GetReachability(ctx, reqs) + c.mx.Lock() + defer c.mx.Unlock() + if err != nil && !errors.Is(err, context.Canceled) { // ignore canceled errors, they're not errors from autonatv2 + c.consecutiveErrors++ + if c.consecutiveErrors > c.MaxConsecutiveErrors { + err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err) + } + if errors.Is(err, autonatv2.ErrPrivateAddrs) { + log.Errorf("private IP addr in autonatv2 request: %s", err) + } + } else { + c.consecutiveErrors = 0 + } + return res, err +} + +const maxConsecutiveErrors = 20 + +// isErrorPersistent returns whether the error will repeat on future probes for a while +func isErrorPersistent(err error) bool { + if err == nil { + return false + } + return errors.Is(err, autonatv2.ErrPrivateAddrs) || errors.Is(err, autonatv2.ErrNoPeers) || + errors.Is(err, errTooManyConsecutiveFailures) +} + +const ( + // recentProbeInterval is the interval to probe addresses that have been refused + // these are generally addresses with newer transports for which we don't have many peers + // capable of dialing the transport + recentProbeInterval = 10 * time.Minute + // maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which + // we wait for `recentProbeInterval` before probing again + maxConsecutiveRefusals = 5 + // maxRecentDialsPerAddr is the maximum number of dials on an address before we stop probing for the address. + // This is used to prevent infinite probing of an address whose status is indeterminate for any reason. + maxRecentDialsPerAddr = 10 + // confidence is the absolute difference between the number of successes and failures for an address + // targetConfidence is the confidence threshold for an address after which we wait for `maxProbeInterval` + // before probing again. 
+ targetConfidence = 3 + // minConfidence is the confidence threshold for an address to be considered reachable or unreachable + // confidence is the absolute difference between the number of successes and failures for an address + minConfidence = 2 + // maxRecentDialsWindow is the maximum number of recent probe results to consider for a single address + // + // +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure + // and then a success(...S S S S F S). The confidence in the targetConfidence window will be equal to + // targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval. + maxRecentDialsWindow = targetConfidence + 2 + // highConfidenceAddrProbeInterval is the maximum interval between probes for an address + highConfidenceAddrProbeInterval = 1 * time.Hour + // maxProbeResultTTL is the maximum time to keep probe results for an address + maxProbeResultTTL = maxRecentDialsWindow * highConfidenceAddrProbeInterval +) + +// probeManager tracks reachability for a set of addresses by periodically probing reachability with autonatv2. +// A Probe is a list of addresses which can be tested for reachability with autonatv2. +// This struct decides the priority order of addresses for testing reachability, and throttles in case there have +// been too many probes for an address in the `ProbeInterval`. +// +// Use the `runProbes` function to execute the probes with an autonatv2 client. +type probeManager struct { + now func() time.Time + + mx sync.Mutex + inProgressProbes map[string]int // addr -> count + inProgressProbesTotal int + statuses map[string]*addrStatus + addrs []ma.Multiaddr +} + +// newProbeManager creates a new probe manager. +func newProbeManager(now func() time.Time) *probeManager { + return &probeManager{ + statuses: make(map[string]*addrStatus), + inProgressProbes: make(map[string]int), + now: now, + } +} + +// AppendConfirmedAddrs appends the current confirmed reachable and unreachable addresses. +func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) { + m.mx.Lock() + defer m.mx.Unlock() + + for _, a := range m.addrs { + s := m.statuses[string(a.Bytes())] + s.RemoveBefore(m.now().Add(-maxProbeResultTTL)) // cleanup stale results + switch s.Reachability() { + case network.ReachabilityPublic: + reachable = append(reachable, a) + case network.ReachabilityPrivate: + unreachable = append(unreachable, a) + } + } + return reachable, unreachable +} + +// UpdateAddrs updates the tracked addrs +func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) { + m.mx.Lock() + defer m.mx.Unlock() + + slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) }) + statuses := make(map[string]*addrStatus, len(addrs)) + for _, addr := range addrs { + k := string(addr.Bytes()) + if _, ok := m.statuses[k]; !ok { + statuses[k] = &addrStatus{Addr: addr} + } else { + statuses[k] = m.statuses[k] + } + } + m.addrs = addrs + m.statuses = statuses +} + +// GetProbe returns the next probe. Returns zero value in case there are no more probes. +// Probes that are run against an autonatv2 client should be marked in progress with +// `MarkProbeInProgress` before running. 
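+//
+// Typical worker-loop usage (a sketch of what refreshReachability's workers do;
+// pm, client and ctx are illustrative):
+//
+//	reqs := pm.GetProbe()
+//	if len(reqs) == 0 {
+//		return
+//	}
+//	pm.MarkProbeInProgress(reqs)
+//	res, err := client.GetReachability(ctx, reqs)
+//	pm.CompleteProbe(reqs, res, err)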
+func (m *probeManager) GetProbe() probe { + m.mx.Lock() + defer m.mx.Unlock() + + now := m.now() + for i, a := range m.addrs { + ab := a.Bytes() + pc := m.statuses[string(ab)].RequiredProbeCount(now) + if m.inProgressProbes[string(ab)] >= pc { + continue + } + reqs := make(probe, 0, maxAddrsPerRequest) + reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true}) + // We have the first(primary) address. Append other addresses, ignoring inprogress probes + // on secondary addresses. The expectation is that the primary address will + // be dialed. + for j := 1; j < len(m.addrs); j++ { + k := (i + j) % len(m.addrs) + ab := m.addrs[k].Bytes() + pc := m.statuses[string(ab)].RequiredProbeCount(now) + if pc == 0 { + continue + } + reqs = append(reqs, autonatv2.Request{Addr: m.addrs[k], SendDialData: true}) + if len(reqs) >= maxAddrsPerRequest { + break + } + } + return reqs + } + return nil +} + +// MarkProbeInProgress should be called when a probe is started. +// All in progress probes *MUST* be completed with `CompleteProbe` +func (m *probeManager) MarkProbeInProgress(reqs probe) { + if len(reqs) == 0 { + return + } + m.mx.Lock() + defer m.mx.Unlock() + m.inProgressProbes[string(reqs[0].Addr.Bytes())]++ + m.inProgressProbesTotal++ +} + +// InProgressProbes returns the number of probes that are currently in progress. +func (m *probeManager) InProgressProbes() int { + m.mx.Lock() + defer m.mx.Unlock() + return m.inProgressProbesTotal +} + +// CompleteProbe should be called when a probe completes. +func (m *probeManager) CompleteProbe(reqs probe, res autonatv2.Result, err error) { + now := m.now() + + if len(reqs) == 0 { + // should never happen + return + } + + m.mx.Lock() + defer m.mx.Unlock() + + // decrement in-progress count for the first address + primaryAddrKey := string(reqs[0].Addr.Bytes()) + m.inProgressProbes[primaryAddrKey]-- + if m.inProgressProbes[primaryAddrKey] <= 0 { + delete(m.inProgressProbes, primaryAddrKey) + } + m.inProgressProbesTotal-- + + // nothing to do if the request errored. + if err != nil { + return + } + + // Consider only primary address as refused. This increases the number of + // refused probes, but refused probes are cheap for a server as no dials are made. 
+ if res.AllAddrsRefused { + if s, ok := m.statuses[primaryAddrKey]; ok { + s.AddRefusal(now) + } + return + } + dialAddrKey := string(res.Addr.Bytes()) + if dialAddrKey != primaryAddrKey { + if s, ok := m.statuses[primaryAddrKey]; ok { + s.AddRefusal(now) + } + } + + // record the result for the dialed address + if s, ok := m.statuses[dialAddrKey]; ok { + s.AddOutcome(now, res.Reachability, maxRecentDialsWindow) + } +} + +type dialOutcome struct { + Success bool + At time.Time +} + +type addrStatus struct { + Addr ma.Multiaddr + lastRefusalTime time.Time + consecutiveRefusals int + dialTimes []time.Time + outcomes []dialOutcome +} + +func (s *addrStatus) Reachability() network.Reachability { + rch, _, _ := s.reachabilityAndCounts() + return rch +} + +func (s *addrStatus) RequiredProbeCount(now time.Time) int { + if s.consecutiveRefusals >= maxConsecutiveRefusals { + if now.Sub(s.lastRefusalTime) < recentProbeInterval { + return 0 + } + // reset every `recentProbeInterval` + s.lastRefusalTime = time.Time{} + s.consecutiveRefusals = 0 + } + + // Don't probe if we have probed too many times recently + rd := s.recentDialCount(now) + if rd >= maxRecentDialsPerAddr { + return 0 + } + + return s.requiredProbeCountForConfirmation(now) +} + +func (s *addrStatus) requiredProbeCountForConfirmation(now time.Time) int { + reachability, successes, failures := s.reachabilityAndCounts() + confidence := successes - failures + if confidence < 0 { + confidence = -confidence + } + cnt := targetConfidence - confidence + if cnt > 0 { + return cnt + } + // we have enough confirmations; check if we should refresh + + // Should never happen. The confidence logic above should require a few probes. + if len(s.outcomes) == 0 { + return 0 + } + lastOutcome := s.outcomes[len(s.outcomes)-1] + // If the last probe result is old, we need to retest + if now.Sub(lastOutcome.At) > highConfidenceAddrProbeInterval { + return 1 + } + // if the last probe result was different from reachability, probe again. 
+ switch reachability { + case network.ReachabilityPublic: + if !lastOutcome.Success { + return 1 + } + case network.ReachabilityPrivate: + if lastOutcome.Success { + return 1 + } + default: + // this should never happen + return 1 + } + return 0 +} + +func (s *addrStatus) AddRefusal(now time.Time) { + s.lastRefusalTime = now + s.consecutiveRefusals++ +} + +func (s *addrStatus) AddOutcome(at time.Time, rch network.Reachability, windowSize int) { + s.lastRefusalTime = time.Time{} + s.consecutiveRefusals = 0 + + s.dialTimes = append(s.dialTimes, at) + for i, t := range s.dialTimes { + if at.Sub(t) < recentProbeInterval { + s.dialTimes = slices.Delete(s.dialTimes, 0, i) + break + } + } + + s.RemoveBefore(at.Add(-maxProbeResultTTL)) // remove old outcomes + success := false + switch rch { + case network.ReachabilityPublic: + success = true + case network.ReachabilityPrivate: + success = false + default: + return // don't store the outcome if reachability is unknown + } + s.outcomes = append(s.outcomes, dialOutcome{At: at, Success: success}) + if len(s.outcomes) > windowSize { + s.outcomes = slices.Delete(s.outcomes, 0, len(s.outcomes)-windowSize) + } +} + +// RemoveBefore removes outcomes before t +func (s *addrStatus) RemoveBefore(t time.Time) { + end := 0 + for ; end < len(s.outcomes); end++ { + if !s.outcomes[end].At.Before(t) { + break + } + } + s.outcomes = slices.Delete(s.outcomes, 0, end) +} + +func (s *addrStatus) recentDialCount(now time.Time) int { + cnt := 0 + for _, t := range slices.Backward(s.dialTimes) { + if now.Sub(t) > recentProbeInterval { + break + } + cnt++ + } + return cnt +} + +func (s *addrStatus) reachabilityAndCounts() (rch network.Reachability, successes int, failures int) { + for _, r := range s.outcomes { + if r.Success { + successes++ + } else { + failures++ + } + } + if successes-failures >= minConfidence { + return network.ReachabilityPublic, successes, failures + } + if failures-successes >= minConfidence { + return network.ReachabilityPrivate, successes, failures + } + return network.ReachabilityUnknown, successes, failures +} diff --git a/p2p/host/basic/addrs_reachability_tracker_test.go b/p2p/host/basic/addrs_reachability_tracker_test.go new file mode 100644 index 0000000000..a58b60db48 --- /dev/null +++ b/p2p/host/basic/addrs_reachability_tracker_test.go @@ -0,0 +1,919 @@ +package basichost + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "math/rand" + "net" + "net/netip" + "slices" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProbeManager(t *testing.T) { + pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1") + pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1") + + cl := clock.NewMock() + + nextProbe := func(pm *probeManager) []autonatv2.Request { + reqs := pm.GetProbe() + if len(reqs) != 0 { + pm.MarkProbeInProgress(reqs) + } + return reqs + } + + makeNewProbeManager := func(addrs []ma.Multiaddr) *probeManager { + pm := newProbeManager(cl.Now) + pm.UpdateAddrs(addrs) + return pm + } + + t.Run("addrs updates", func(t *testing.T) { + pm := newProbeManager(cl.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub1, pub2}) + for { + reqs := nextProbe(pm) + if len(reqs) == 0 { + break + } + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, 
Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } + reachable, _ := pm.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1, pub2}) + pm.UpdateAddrs([]ma.Multiaddr{pub3}) + + reachable, _ = pm.AppendConfirmedAddrs(nil, nil) + require.Empty(t, reachable) + require.Len(t, pm.statuses, 1) + }) + + t.Run("inprogress", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + reqs1 := pm.GetProbe() + reqs2 := pm.GetProbe() + require.Equal(t, reqs1, reqs2) + for range targetConfidence { + reqs := nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + } + for range targetConfidence { + reqs := nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}}) + } + reqs := pm.GetProbe() + require.Empty(t, reqs) + }) + + t.Run("refusals", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + var probes [][]autonatv2.Request + for range targetConfidence { + reqs := nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + probes = append(probes, reqs) + } + // first one refused second one successful + for _, p := range probes { + pm.CompleteProbe(p, autonatv2.Result{Addr: pub2, Idx: 1, Reachability: network.ReachabilityPublic}, nil) + } + // the second address is validated! + probes = nil + for range targetConfidence { + reqs := nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) + probes = append(probes, reqs) + } + reqs := pm.GetProbe() + require.Empty(t, reqs) + for _, p := range probes { + pm.CompleteProbe(p, autonatv2.Result{AllAddrsRefused: true}, nil) + } + // all requests refused; no more probes for too many refusals + reqs = pm.GetProbe() + require.Empty(t, reqs) + + cl.Add(recentProbeInterval) + reqs = pm.GetProbe() + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}}) + }) + + t.Run("successes", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + for j := 0; j < 2; j++ { + for i := 0; i < targetConfidence; i++ { + reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } + } + // all addrs confirmed + reqs := pm.GetProbe() + require.Empty(t, reqs) + + cl.Add(highConfidenceAddrProbeInterval + time.Millisecond) + reqs = nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + reqs = nextProbe(pm) + require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}}) + }) + + t.Run("throttling on indeterminate reachability", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + reachability := network.ReachabilityPublic + nextReachability := func() network.Reachability { + if reachability == network.ReachabilityPublic { + reachability = network.ReachabilityPrivate + } else { + reachability = network.ReachabilityPublic + } + return reachability + } + // both addresses are indeterminate + for range 2 * maxRecentDialsPerAddr { + reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil) + } + reqs := pm.GetProbe() + require.Empty(t, reqs) + + cl.Add(recentProbeInterval + time.Millisecond) + reqs = pm.GetProbe() + 
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}}) + for range 2 * maxRecentDialsPerAddr { + reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil) + } + reqs = pm.GetProbe() + require.Empty(t, reqs) + }) + + t.Run("reachabilityUpdate", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2}) + for range 2 * targetConfidence { + reqs := nextProbe(pm) + if reqs[0].Addr.Equal(pub1) { + pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } else { + pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub2, Idx: 0, Reachability: network.ReachabilityPrivate}, nil) + } + } + + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Equal(t, unreachable, []ma.Multiaddr{pub2}) + }) + t.Run("expiry", func(t *testing.T) { + pm := makeNewProbeManager([]ma.Multiaddr{pub1}) + for range 2 * targetConfidence { + reqs := nextProbe(pm) + pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } + + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) + + cl.Add(maxProbeResultTTL + 1*time.Second) + reachable, unreachable = pm.AppendConfirmedAddrs(nil, nil) + require.Empty(t, reachable) + require.Empty(t, unreachable) + }) +} + +type mockAutoNATClient struct { + F func(context.Context, []autonatv2.Request) (autonatv2.Result, error) +} + +func (m mockAutoNATClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + return m.F(ctx, reqs) +} + +var _ autonatv2Client = mockAutoNATClient{} + +func TestAddrsReachabilityTracker(t *testing.T) { + pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1") + pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1") + pri := ma.StringCast("/ip4/192.168.1.1/tcp/1") + + newTracker := func(cli mockAutoNATClient, cl clock.Clock) *addrsReachabilityTracker { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if cl == nil { + cl = clock.New() + } + tr := &addrsReachabilityTracker{ + ctx: ctx, + cancel: cancel, + client: cli, + newAddrs: make(chan []ma.Multiaddr, 1), + reachabilityUpdateCh: make(chan struct{}, 1), + maxConcurrency: 3, + newAddrsProbeDelay: 0 * time.Second, + probeManager: newProbeManager(cl.Now), + clock: cl, + } + err := tr.Start() + require.NoError(t, err) + t.Cleanup(func() { + err := tr.Close() + assert.NoError(t, err) + }) + return tr + } + + t.Run("simple", func(t *testing.T) { + // pub1 reachable, pub2 unreachable, pub3 ignored + mockClient := mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + for i, req := range reqs { + if req.Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil + } else if req.Addr.Equal(pub2) { + return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil + } + } + return autonatv2.Result{}, autonatv2.ErrNoPeers + }, + } + tr := newTracker(mockClient, nil) + tr.UpdateAddrs([]ma.Multiaddr{pub2, pub1, pri}) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(2 * time.Second): + t.Fatal("expected reachability update") + } + reachable, unreachable := tr.ConfirmedAddrs() + 
require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1) + require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2) + + tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pri}) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(2 * time.Second): + t.Fatal("expected reachability update") + } + reachable, unreachable = tr.ConfirmedAddrs() + require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1) + require.Empty(t, unreachable) + }) + + t.Run("confirmed addrs ordering", func(t *testing.T) { + mockClient := mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil + }, + } + tr := newTracker(mockClient, nil) + var addrs []ma.Multiaddr + for i := 0; i < 10; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", i))) + } + slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return -a.Compare(b) }) // sort in reverse order + tr.UpdateAddrs(addrs) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(2 * time.Second): + t.Fatal("expected reachability update") + } + reachable, unreachable := tr.ConfirmedAddrs() + require.Empty(t, unreachable) + + orderedAddrs := slices.Clone(addrs) + slices.Reverse(orderedAddrs) + require.Equal(t, reachable, orderedAddrs, "%s %s", reachable, addrs) + }) + + t.Run("backoff", func(t *testing.T) { + notify := make(chan struct{}, 1) + drainNotify := func() bool { + found := false + for { + select { + case <-notify: + found = true + default: + return found + } + } + } + + var allow atomic.Bool + mockClient := mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + select { + case notify <- struct{}{}: + default: + } + if !allow.Load() { + return autonatv2.Result{}, autonatv2.ErrNoPeers + } + if reqs[0].Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil + } + return autonatv2.Result{AllAddrsRefused: true}, nil + }, + } + + cl := clock.NewMock() + tr := newTracker(mockClient, cl) + + // update addrs and wait for initial checks + tr.UpdateAddrs([]ma.Multiaddr{pub1}) + // need to update clock after the background goroutine processes the new addrs + time.Sleep(100 * time.Millisecond) + cl.Add(1) + time.Sleep(100 * time.Millisecond) + require.True(t, drainNotify()) // check that we did receive probes + + backoffInterval := backoffStartInterval + for i := 0; i < 4; i++ { + drainNotify() + cl.Add(backoffInterval / 2) + select { + case <-notify: + t.Fatal("unexpected call") + case <-time.After(50 * time.Millisecond): + } + cl.Add(backoffInterval/2 + 1) // +1 to push it slightly over the backoff interval + backoffInterval *= 2 + select { + case <-notify: + case <-time.After(1 * time.Second): + t.Fatal("expected probe") + } + reachable, unreachable := tr.ConfirmedAddrs() + require.Empty(t, reachable) + require.Empty(t, unreachable) + } + allow.Store(true) + drainNotify() + cl.Add(backoffInterval + 1) + select { + case <-tr.reachabilityUpdateCh: + case <-time.After(1 * time.Second): + t.Fatal("unexpected reachability update") + } + reachable, unreachable := tr.ConfirmedAddrs() + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) + }) + + t.Run("event update", func(t *testing.T) { + // allow minConfidence probes to pass + called := make(chan struct{}, minConfidence) + notify := make(chan struct{}) + 
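// the mock lets only minConfidence probes succeed; once the called channel is full, further requests are refused +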
mockClient := mockAutoNATClient{ + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { + select { + case called <- struct{}{}: + notify <- struct{}{} + return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil + default: + return autonatv2.Result{AllAddrsRefused: true}, nil + } + }, + } + + tr := newTracker(mockClient, nil) + tr.UpdateAddrs([]ma.Multiaddr{pub1}) + for i := 0; i < minConfidence; i++ { + select { + case <-notify: + case <-time.After(1 * time.Second): + t.Fatal("expected call to autonat client") + } + } + select { + case <-tr.reachabilityUpdateCh: + reachable, unreachable := tr.ConfirmedAddrs() + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) + case <-time.After(1 * time.Second): + t.Fatal("expected reachability update") + } + tr.UpdateAddrs([]ma.Multiaddr{pub1}) // same addrs shouldn't get update + select { + case <-tr.reachabilityUpdateCh: + t.Fatal("didn't expect reachability update") + case <-time.After(100 * time.Millisecond): + } + tr.UpdateAddrs([]ma.Multiaddr{pub2}) + select { + case <-tr.reachabilityUpdateCh: + reachable, unreachable := tr.ConfirmedAddrs() + require.Empty(t, reachable) + require.Empty(t, unreachable) + case <-time.After(1 * time.Second): + t.Fatal("expected reachability update") + } + }) + + t.Run("refresh after reset interval", func(t *testing.T) { + notify := make(chan struct{}, 1) + drainNotify := func() bool { + found := false + for { + select { + case <-notify: + found = true + default: + return found + } + } + } + + mockClient := mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + select { + case notify <- struct{}{}: + default: + } + if reqs[0].Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil + } + return autonatv2.Result{AllAddrsRefused: true}, nil + }, + } + + cl := clock.NewMock() + tr := newTracker(mockClient, cl) + + // update addrs and wait for initial checks + tr.UpdateAddrs([]ma.Multiaddr{pub1}) + // need to update clock after the background goroutine processes the new addrs + time.Sleep(100 * time.Millisecond) + cl.Add(1) + time.Sleep(100 * time.Millisecond) + require.True(t, drainNotify()) // check that we did receive probes + cl.Add(highConfidenceAddrProbeInterval / 2) + select { + case <-notify: + t.Fatal("unexpected call") + case <-time.After(50 * time.Millisecond): + } + + cl.Add(highConfidenceAddrProbeInterval/2 + defaultReachabilityRefreshInterval) // defaultResetInterval for the next probe time + select { + case <-notify: + case <-time.After(1 * time.Second): + t.Fatal("expected probe") + } + }) +} + +func TestRefreshReachability(t *testing.T) { + pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + pub2 := ma.StringCast("/ip4/1.1.1.1/tcp/2") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + newTracker := func(client autonatv2Client, pm *probeManager) *addrsReachabilityTracker { + return &addrsReachabilityTracker{ + probeManager: pm, + client: client, + clock: clock.New(), + maxConcurrency: 3, + ctx: ctx, + cancel: cancel, + } + } + t.Run("backoff on ErrNoValidPeers", func(t *testing.T) { + mockClient := mockAutoNATClient{ + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { + return autonatv2.Result{}, autonatv2.ErrNoPeers + }, + } + + addrTracker := newProbeManager(time.Now) + addrTracker.UpdateAddrs([]ma.Multiaddr{pub1}) + r := newTracker(mockClient, 
addrTracker) + res := r.refreshReachability() + require.True(t, <-res.BackoffCh) + require.Equal(t, addrTracker.InProgressProbes(), 0) + }) + + t.Run("returns backoff on errTooManyConsecutiveFailures", func(t *testing.T) { + // Create a client that always returns ErrDialRefused + mockClient := mockAutoNATClient{ + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { + return autonatv2.Result{}, errors.New("test error") + }, + } + + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub1}) + r := newTracker(mockClient, pm) + result := r.refreshReachability() + require.True(t, <-result.BackoffCh) + require.Equal(t, pm.InProgressProbes(), 0) + }) + + t.Run("quits on cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + block := make(chan struct{}) + mockClient := mockAutoNATClient{ + F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) { + block <- struct{}{} + return autonatv2.Result{}, nil + }, + } + + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub1}) + r := &addrsReachabilityTracker{ + ctx: ctx, + cancel: cancel, + client: mockClient, + probeManager: pm, + clock: clock.New(), + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + result := r.refreshReachability() + assert.False(t, <-result.BackoffCh) + assert.Equal(t, pm.InProgressProbes(), 0) + }() + + cancel() + time.Sleep(50 * time.Millisecond) // wait for the cancellation to be processed + + outer: + for i := 0; i < defaultMaxConcurrency; i++ { + select { + case <-block: + default: + break outer + } + } + select { + case <-block: + t.Fatal("expected no more requests") + case <-time.After(50 * time.Millisecond): + } + wg.Wait() + }) + + t.Run("handles refusals", func(t *testing.T) { + pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1") + + mockClient := mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + for i, req := range reqs { + if req.Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil + } + } + return autonatv2.Result{AllAddrsRefused: true}, nil + }, + } + + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) + r := newTracker(mockClient, pm) + + result := r.refreshReachability() + require.False(t, <-result.BackoffCh) + + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Empty(t, unreachable) + require.Equal(t, pm.InProgressProbes(), 0) + }) + + t.Run("handles completions", func(t *testing.T) { + mockClient := mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + for i, req := range reqs { + if req.Addr.Equal(pub1) { + return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil + } + if req.Addr.Equal(pub2) { + return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil + } + } + return autonatv2.Result{AllAddrsRefused: true}, nil + }, + } + pm := newProbeManager(time.Now) + pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1}) + r := newTracker(mockClient, pm) + result := r.refreshReachability() + require.False(t, <-result.BackoffCh) + + reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil) + require.Equal(t, reachable, []ma.Multiaddr{pub1}) + require.Equal(t, unreachable, []ma.Multiaddr{pub2}) + require.Equal(t, pm.InProgressProbes(), 0) + }) +} + +func 
TestAddrStatusProbeCount(t *testing.T) { + cases := []struct { + inputs string + wantRequiredProbes int + wantReachability network.Reachability + }{ + { + inputs: "", + wantRequiredProbes: 3, + wantReachability: network.ReachabilityUnknown, + }, + { + inputs: "S", + wantRequiredProbes: 2, + wantReachability: network.ReachabilityUnknown, + }, + { + inputs: "SS", + wantRequiredProbes: 1, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SSS", + wantRequiredProbes: 0, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SSSSSSSF", + wantRequiredProbes: 1, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SFSFSSSS", + wantRequiredProbes: 0, + wantReachability: network.ReachabilityPublic, + }, + { + inputs: "SSSSSFSF", + wantRequiredProbes: 2, + wantReachability: network.ReachabilityUnknown, + }, + { + inputs: "FF", + wantRequiredProbes: 1, + wantReachability: network.ReachabilityPrivate, + }, + } + for _, c := range cases { + t.Run(c.inputs, func(t *testing.T) { + now := time.Time{}.Add(1 * time.Second) + ao := addrStatus{} + for _, r := range c.inputs { + if r == 'S' { + ao.AddOutcome(now, network.ReachabilityPublic, 5) + } else { + ao.AddOutcome(now, network.ReachabilityPrivate, 5) + } + now = now.Add(1 * time.Second) + } + require.Equal(t, ao.RequiredProbeCount(now), c.wantRequiredProbes) + require.Equal(t, ao.Reachability(), c.wantReachability) + if c.wantRequiredProbes == 0 { + now = now.Add(highConfidenceAddrProbeInterval + 10*time.Microsecond) + require.Equal(t, ao.RequiredProbeCount(now), 1) + } + + now = now.Add(1 * time.Second) + ao.RemoveBefore(now) + require.Len(t, ao.outcomes, 0) + }) + } +} + +func BenchmarkAddrTracker(b *testing.B) { + cl := clock.NewMock() + t := newProbeManager(cl.Now) + + addrs := make([]ma.Multiaddr, 20) + for i := range addrs { + addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", rand.Intn(1000))) + } + t.UpdateAddrs(addrs) + b.ReportAllocs() + b.ResetTimer() + p := t.GetProbe() + for i := 0; i < b.N; i++ { + pp := t.GetProbe() + if len(pp) == 0 { + pp = p + } + t.MarkProbeInProgress(pp) + t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil) + } +} + +func FuzzAddrsReachabilityTracker(f *testing.F) { + type autonatv2Response struct { + Result autonatv2.Result + Err error + } + + newMockClient := func(b []byte) mockAutoNATClient { + count := 0 + return mockAutoNATClient{ + F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) { + if len(b) == 0 { + return autonatv2.Result{}, nil + } + count = (count + 1) % len(b) + if b[count]%3 == 0 { + // some address confirmed + c1 := (count + 1) % len(b) + c2 := (count + 2) % len(b) + rch := network.Reachability(b[c1] % 3) + n := int(b[c2]) % len(reqs) + return autonatv2.Result{ + Addr: reqs[n].Addr, + Idx: n, + Reachability: rch, + }, nil + } + outcomes := []autonatv2Response{ + {Result: autonatv2.Result{AllAddrsRefused: true}}, + {Err: errors.New("test error")}, + {Err: autonatv2.ErrPrivateAddrs}, + {Err: autonatv2.ErrNoPeers}, + {Result: autonatv2.Result{}, Err: nil}, + {Result: autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}}, + {Result: autonatv2.Result{ + Addr: reqs[0].Addr, + Idx: 0, + Reachability: network.ReachabilityPublic, + AllAddrsRefused: true, + }}, + {Result: autonatv2.Result{ + Addr: reqs[0].Addr, + Idx: len(reqs) - 1, // invalid idx + Reachability: network.ReachabilityPublic, + AllAddrsRefused: false, + }}, + } + outcome := 
outcomes[int(b[count])%len(outcomes)] + return outcome.Result, outcome.Err + }, + } + } + + // TODO: Move this to go-multiaddrs + getProto := func(protos []byte) ma.Multiaddr { + protoType := 0 + if len(protos) > 0 { + protoType = int(protos[0]) + } + + port1, port2 := 0, 0 + if len(protos) > 1 { + port1 = int(protos[1]) + } + if len(protos) > 2 { + port2 = int(protos[2]) + } + protoTemplates := []string{ + "/tcp/%d/", + "/udp/%d/", + "/udp/%d/quic-v1/", + "/udp/%d/quic-v1/tcp/%d", + "/udp/%d/quic-v1/webtransport/", + "/udp/%d/webrtc/", + "/udp/%d/webrtc-direct/", + "/unix/hello/", + } + s := protoTemplates[protoType%len(protoTemplates)] + port1 %= (1 << 16) + if strings.Count(s, "%d") == 1 { + return ma.StringCast(fmt.Sprintf(s, port1)) + } + port2 %= (1 << 16) + return ma.StringCast(fmt.Sprintf(s, port1, port2)) + } + + getIP := func(ips []byte) ma.Multiaddr { + ipType := 0 + if len(ips) > 0 { + ipType = int(ips[0]) + } + ips = ips[1:] + var x, y int64 + split := 128 / 8 + if len(ips) < split { + split = len(ips) + } + var b [8]byte + copy(b[:], ips[:split]) + x = int64(binary.LittleEndian.Uint64(b[:])) + clear(b[:]) + copy(b[:], ips[split:]) + y = int64(binary.LittleEndian.Uint64(b[:])) + + var ip netip.Addr + switch ipType % 3 { + case 0: + ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)}) + return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip)) + case 1: + pubIP := net.ParseIP("2005::") // Public IP address + x := int64(binary.LittleEndian.Uint64(pubIP[0:8])) + ip = netip.AddrFrom16([16]byte{ + byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24), + byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56), + byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24), + byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56), + }) + return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip)) + default: + ip := netip.AddrFrom16([16]byte{ + byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24), + byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56), + byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24), + byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56), + }) + return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip)) + } + } + + getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr { + switch addrType % 4 { + case 0: + return getIP(ips).Encapsulate(getProto(protos)) + case 1: + return getProto(protos) + case 2: + return nil + default: + return getIP(ips).Encapsulate(getProto(protos)) + } + } + + getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr { + hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "") + hostName = strings.ReplaceAll(hostName, "/", "") + if hostName == "" { + hostName = "localhost" + } + dnsType := 0 + if len(hostNameBytes) > 0 { + dnsType = int(hostNameBytes[0]) + } + dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} + da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName)) + return da.Encapsulate(getProto(protos)) + } + + const maxAddrs = 1000 + getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr { + if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 { + return nil + } + numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs + addrs := make([]ma.Multiaddr, numAddrs) + ipIdx := 0 + protoIdx := 0 + for i := range numAddrs { + addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:]) + ipIdx = (ipIdx + 1) % len(ips) + protoIdx = (protoIdx + 1) % len(protos) + } + maxDNSAddrs := 10 + protoIdx = 0 + for i := 0; i < len(hostNames) && i < maxDNSAddrs; i 
+= 2 { + ed := min(i+2, len(hostNames)) + addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:])) + protoIdx = (protoIdx + 1) % len(protos) + } + return addrs + } + + cl := clock.NewMock() + f.Fuzz(func(t *testing.T, numAddrs int, ips, protos, hostNames, autonatResponses []byte) { + tr := newAddrsReachabilityTracker(newMockClient(autonatResponses), nil, cl) + require.NoError(t, tr.Start()) + tr.UpdateAddrs(getAddrs(numAddrs, ips, protos, hostNames)) + + // fuzz tests need to finish in 10 seconds for some reason + // https://github.com/golang/go/issues/48157 + // https://github.com/golang/go/commit/5d24203c394e6b64c42a9f69b990d94cb6c8aad4#diff-4e3b9481b8794eb058998e2bec389d3db7a23c54e67ac0f7259a3a5d2c79fd04R474-R483 + const maxIters = 20 + for range maxIters { + cl.Add(5 * time.Minute) + time.Sleep(100 * time.Millisecond) + } + require.NoError(t, tr.Close()) + }) +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index b4db8c2091..d7b9f40ab2 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -156,8 +156,8 @@ type HostOpts struct { // DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify DisableIdentifyAddressDiscovery bool - EnableAutoNATv2 bool - AutoNATv2Dialer host.Host + + AutoNATv2 *autonatv2.AutoNAT } // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. @@ -236,7 +236,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { }); ok { tfl = s.TransportForListening } - h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan) + + if opts.AutoNATv2 != nil { + h.autonatv2 = opts.AutoNATv2 + } + + var autonatv2Client autonatv2Client // avoid typed nil errors + if h.autonatv2 != nil { + autonatv2Client = h.autonatv2 + } + h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan, autonatv2Client) if err != nil { return nil, fmt.Errorf("failed to create address service: %w", err) } @@ -283,17 +292,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h.pings = ping.NewPingService(h) } - if opts.EnableAutoNATv2 { - var mt autonatv2.MetricsTracer - if opts.EnableMetrics { - mt = autonatv2.NewMetricsTracer(opts.PrometheusRegisterer) - } - h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer, autonatv2.WithMetricsTracer(mt)) - if err != nil { - return nil, fmt.Errorf("failed to create autonatv2: %w", err) - } - } - if !h.disableSignedPeerRecord { h.signKey = h.Peerstore().PrivKey(h.ID()) cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) @@ -320,7 +318,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { func (h *BasicHost) Start() { h.psManager.Start() if h.autonatv2 != nil { - err := h.autonatv2.Start() + err := h.autonatv2.Start(h) if err != nil { log.Errorf("autonat v2 failed to start: %s", err) } @@ -754,6 +752,16 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { return h.addressManager.DirectAddrs() } +// ReachableAddrs returns all addresses of the host that are reachable from the internet +// as verified by autonatv2. +// +// Experimental: This API may change in the future without deprecation. +// +// Requires AutoNATv2 to be enabled. 
+func (h *BasicHost) ReachableAddrs() []ma.Multiaddr { + return h.addressManager.ReachableAddrs() +} + func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr { totalSize := 0 for _, a := range addrs { @@ -836,7 +844,6 @@ func (h *BasicHost) Close() error { if h.cmgr != nil { h.cmgr.Close() } - h.addressManager.Close() if h.ids != nil { diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 80a1d3dd62..757d43f27b 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -47,6 +47,7 @@ func TestHostSimple(t *testing.T) { h1.Start() h2, err := NewHost(swarmt.GenSwarm(t), nil) require.NoError(t, err) + defer h2.Close() h2.Start() @@ -211,6 +212,7 @@ func TestAllAddrs(t *testing.T) { // no listen addrs h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil) require.NoError(t, err) + h.Start() defer h.Close() require.Nil(t, h.AllAddrs()) diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index ee4318dd87..866c02381f 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -4,18 +4,17 @@ import ( "context" "errors" "fmt" + "iter" + "math/rand/v2" "slices" "sync" "time" - "math/rand/v2" - logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) @@ -35,11 +34,15 @@ const ( // maxPeerAddresses is the number of addresses in a dial request the server // will inspect, rest are ignored. maxPeerAddresses = 50 + + defaultThrottlePeerDuration = 2 * time.Minute ) var ( - ErrNoValidPeers = errors.New("no valid peers for autonat v2") - ErrDialRefused = errors.New("dial refused") + // ErrNoPeers is returned when the client knows no autonatv2 servers. + ErrNoPeers = errors.New("no peers for autonat v2") + // ErrPrivateAddrs is returned when the request has private IP addresses. + ErrPrivateAddrs = errors.New("private addresses cannot be verified with autonatv2") log = logging.Logger("autonatv2") ) @@ -56,10 +59,12 @@ type Request struct { type Result struct { // Addr is the dialed address Addr ma.Multiaddr - // Reachability of the dialed address + // Idx is the index of the address that was dialed + Idx int + // Reachability is the reachability for `Addr` Reachability network.Reachability - // Status is the outcome of the dialback - Status pb.DialStatus + // AllAddrsRefused is true when the server refused to dial all the addresses in the request. + AllAddrsRefused bool } // AutoNAT implements the AutoNAT v2 client and server. @@ -76,8 +81,12 @@ type AutoNAT struct { srv *server cli *client - mx sync.Mutex - peers *peersMap + mx sync.Mutex + peers *peersMap + throttlePeer map[peer.ID]time.Time + // throttlePeerDuration is the duration to wait before making another dial request to the + // same server. + throttlePeerDuration time.Duration // allowPrivateAddrs enables using private and localhost addresses for reachability checks. // This is only useful for testing. allowPrivateAddrs bool @@ -86,7 +95,7 @@ type AutoNAT struct { // New returns a new AutoNAT instance. // host and dialerHost should have the same dialing capabilities. In case the host doesn't support // a transport, dial back requests for address for that transport will be ignored. 
-func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { +func New(dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { s := defaultSettings() for _, o := range opts { if err := o(s); err != nil { @@ -96,18 +105,20 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, ctx, cancel := context.WithCancel(context.Background()) an := &AutoNAT{ - host: host, - ctx: ctx, - cancel: cancel, - srv: newServer(host, dialerHost, s), - cli: newClient(host), - allowPrivateAddrs: s.allowPrivateAddrs, - peers: newPeersMap(), + ctx: ctx, + cancel: cancel, + srv: newServer(dialerHost, s), + cli: newClient(), + allowPrivateAddrs: s.allowPrivateAddrs, + peers: newPeersMap(), + throttlePeer: make(map[peer.ID]time.Time), + throttlePeerDuration: s.throttlePeerDuration, } return an, nil } func (an *AutoNAT) background(sub event.Subscription) { + ticker := time.NewTicker(10 * time.Minute) for { select { case <-an.ctx.Done(): @@ -122,12 +133,24 @@ func (an *AutoNAT) background(sub event.Subscription) { an.updatePeer(evt.Peer) case event.EvtPeerIdentificationCompleted: an.updatePeer(evt.Peer) + default: + log.Errorf("unexpected event: %T", e) } + case <-ticker.C: + now := time.Now() + an.mx.Lock() + for p, t := range an.throttlePeer { + if t.Before(now) { + delete(an.throttlePeer, p) + } + } + an.mx.Unlock() } } } -func (an *AutoNAT) Start() error { +func (an *AutoNAT) Start(h host.Host) error { + an.host = h // Listen on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged // event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers. sub, err := an.host.EventBus().Subscribe([]interface{}{ @@ -138,8 +161,8 @@ func (an *AutoNAT) Start() error { if err != nil { return fmt.Errorf("event subscription failed: %w", err) } - an.cli.Start() - an.srv.Start() + an.cli.Start(h) + an.srv.Start(h) an.wg.Add(1) go an.background(sub) @@ -156,24 +179,48 @@ func (an *AutoNAT) Close() { // GetReachability makes a single dial request for checking reachability for requested addresses func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) { + var filteredReqs []Request if !an.allowPrivateAddrs { + filteredReqs = make([]Request, 0, len(reqs)) for _, r := range reqs { - if !manet.IsPublicAddr(r.Addr) { - return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr) + if manet.IsPublicAddr(r.Addr) { + filteredReqs = append(filteredReqs, r) + } else { + log.Errorf("private address in reachability check: %s", r.Addr) } } + if len(filteredReqs) == 0 { + return Result{}, ErrPrivateAddrs + } + } else { + filteredReqs = reqs } an.mx.Lock() - p := an.peers.GetRand() + now := time.Now() + var p peer.ID + for pr := range an.peers.Shuffled() { + if t := an.throttlePeer[pr]; t.After(now) { + continue + } + p = pr + an.throttlePeer[p] = time.Now().Add(an.throttlePeerDuration) + break + } an.mx.Unlock() if p == "" { - return Result{}, ErrNoValidPeers + return Result{}, ErrNoPeers } - - res, err := an.cli.GetReachability(ctx, p, reqs) + res, err := an.cli.GetReachability(ctx, p, filteredReqs) if err != nil { log.Debugf("reachability check with %s failed, err: %s", p, err) - return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err) + return res, fmt.Errorf("reachability check with %s failed: %w", p, err) + } + // restore the correct index in case we'd filtered private addresses + for i, r := range reqs { + if r.Addr.Equal(res.Addr) { + res.Idx = i + break + } } 
log.Debugf("reachability check with %s successful", p) return res, nil @@ -187,7 +234,7 @@ func (an *AutoNAT) updatePeer(p peer.ID) { // and swarm for the current state protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol) connectedness := an.host.Network().Connectedness(p) - if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected { + if err == nil && connectedness == network.Connected && slices.Contains(protos, DialProtocol) { an.peers.Put(p) } else { an.peers.Delete(p) @@ -208,28 +255,40 @@ func newPeersMap() *peersMap { } } -func (p *peersMap) GetRand() peer.ID { - if len(p.peers) == 0 { - return "" +// Shuffled iterates over the map in random order +func (p *peersMap) Shuffled() iter.Seq[peer.ID] { + n := len(p.peers) + start := 0 + if n > 0 { + start = rand.IntN(n) + } + return func(yield func(peer.ID) bool) { + for i := range n { + if !yield(p.peers[(i+start)%n]) { + return + } + } } - return p.peers[rand.IntN(len(p.peers))] } -func (p *peersMap) Put(pid peer.ID) { - if _, ok := p.peerIdx[pid]; ok { +func (p *peersMap) Put(id peer.ID) { + if _, ok := p.peerIdx[id]; ok { return } - p.peers = append(p.peers, pid) - p.peerIdx[pid] = len(p.peers) - 1 + p.peers = append(p.peers, id) + p.peerIdx[id] = len(p.peers) - 1 } -func (p *peersMap) Delete(pid peer.ID) { - idx, ok := p.peerIdx[pid] +func (p *peersMap) Delete(id peer.ID) { + idx, ok := p.peerIdx[id] if !ok { return } - p.peers[idx] = p.peers[len(p.peers)-1] - p.peerIdx[p.peers[idx]] = idx - p.peers = p.peers[:len(p.peers)-1] - delete(p.peerIdx, pid) + n := len(p.peers) + lastPeer := p.peers[n-1] + p.peers[idx] = lastPeer + p.peerIdx[lastPeer] = idx + p.peers[n-1] = "" + p.peers = p.peers[:n-1] + delete(p.peerIdx, id) } diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index cee161de1a..097ac98a23 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -2,8 +2,13 @@ package autonatv2 import ( "context" + "encoding/binary" "errors" "fmt" + "math" + "net" + "net/netip" + "strings" "sync/atomic" "testing" "time" @@ -36,11 +41,12 @@ func newAutoNAT(t testing.TB, dialer host.Host, opts ...AutoNATOption) *AutoNAT swarm.WithUDPBlackHoleSuccessCounter(nil), swarm.WithIPv6BlackHoleSuccessCounter(nil)))) } - an, err := New(h, dialer, opts...) + opts = append([]AutoNATOption{withThrottlePeerDuration(0)}, opts...) + an, err := New(dialer, opts...) 
if err != nil { t.Error(err) } - an.Start() + require.NoError(t, an.Start(h)) t.Cleanup(an.Close) return an } @@ -74,7 +80,7 @@ func waitForPeer(t testing.TB, a *AutoNAT) { require.Eventually(t, func() bool { a.mx.Lock() defer a.mx.Unlock() - return a.peers.GetRand() != "" + return len(a.peers.peers) != 0 }, 5*time.Second, 100*time.Millisecond) } @@ -88,7 +94,7 @@ func TestAutoNATPrivateAddr(t *testing.T) { an := newAutoNAT(t, nil) res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) require.Equal(t, res, Result{}) - require.Contains(t, err.Error(), "private address cannot be verified by autonatv2") + require.ErrorIs(t, err, ErrPrivateAddrs) } func TestClientRequest(t *testing.T) { @@ -154,19 +160,6 @@ func TestClientServerError(t *testing.T) { }, errorStr: "invalid msg type", }, - { - handler: func(s network.Stream) { - w := pbio.NewDelimitedWriter(s) - assert.NoError(t, w.WriteMsg( - &pb.Message{Msg: &pb.Message_DialResponse{ - DialResponse: &pb.DialResponse{ - Status: pb.DialResponse_E_DIAL_REFUSED, - }, - }}, - )) - }, - errorStr: ErrDialRefused.Error(), - }, } for i, tc := range tests { @@ -298,6 +291,49 @@ func TestClientDataRequest(t *testing.T) { } } +func TestAutoNATPrivateAndPublicAddrs(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer dialerHost.Close() + handler := func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_E_DIAL_ERROR, + AddrIdx: 0, + }, + }, + }) + s.Close() + } + + b.SetStreamHandler(DialProtocol, handler) + privateAddr := ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1") + publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/10/quic-v1") + res, err := an.GetReachability(context.Background(), + []Request{ + {Addr: privateAddr}, + {Addr: publicAddr}, + }) + require.NoError(t, err) + require.Equal(t, res.Addr, publicAddr, "%s\n%s", res.Addr, publicAddr) + require.Equal(t, res.Idx, 1) + require.Equal(t, res.Reachability, network.ReachabilityPrivate) +} + func TestClientDialBacks(t *testing.T) { an := newAutoNAT(t, nil, allowPrivateAddrs) defer an.Close() @@ -507,7 +543,6 @@ func TestClientDialBacks(t *testing.T) { } else { require.NoError(t, err) require.Equal(t, res.Reachability, network.ReachabilityPublic) - require.Equal(t, res.Status, pb.DialStatus_OK) } }) } @@ -551,46 +586,6 @@ func TestEventSubscription(t *testing.T) { }, 5*time.Second, 100*time.Millisecond) } -func TestPeersMap(t *testing.T) { - emptyPeerID := peer.ID("") - - t.Run("single_item", func(t *testing.T) { - p := newPeersMap() - p.Put("peer1") - p.Delete("peer1") - p.Put("peer1") - require.Equal(t, peer.ID("peer1"), p.GetRand()) - p.Delete("peer1") - require.Equal(t, emptyPeerID, p.GetRand()) - }) - - t.Run("multiple_items", func(t *testing.T) { - p := newPeersMap() - require.Equal(t, emptyPeerID, p.GetRand()) - - allPeers := make(map[peer.ID]bool) - for i := 0; i < 20; i++ { - pid := peer.ID(fmt.Sprintf("peer-%d", i)) - allPeers[pid] = true - p.Put(pid) - } - foundPeers := make(map[peer.ID]bool) - for i := 0; i < 1000; i++ { - pid := p.GetRand() - require.NotEqual(t, 
emptyPeerID, p) - require.True(t, allPeers[pid]) - foundPeers[pid] = true - if len(foundPeers) == len(allPeers) { - break - } - } - for pid := range allPeers { - p.Delete(pid) - } - require.Equal(t, emptyPeerID, p.GetRand()) - }) -} - func TestAreAddrsConsistency(t *testing.T) { c := &client{ normalizeMultiaddr: func(a ma.Multiaddr) ma.Multiaddr { @@ -645,6 +640,12 @@ func TestAreAddrsConsistency(t *testing.T) { dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"), success: false, }, + { + name: "dns6", + localAddr: ma.StringCast("/dns6/lib.p2p/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/"), + success: false, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { @@ -658,3 +659,173 @@ func TestAreAddrsConsistency(t *testing.T) { }) } } + +func TestPeerMap(t *testing.T) { + pm := newPeersMap() + // Add 1, 2, 3 + pm.Put(peer.ID("1")) + pm.Put(peer.ID("2")) + pm.Put(peer.ID("3")) + + // Remove 3, 2 + pm.Delete(peer.ID("3")) + pm.Delete(peer.ID("2")) + + // Add 4 + pm.Put(peer.ID("4")) + + // Remove 3, 2 again. Should be no op + pm.Delete(peer.ID("3")) + pm.Delete(peer.ID("2")) + + contains := []peer.ID{"1", "4"} + elems := make([]peer.ID, 0) + for p := range pm.Shuffled() { + elems = append(elems, p) + } + require.ElementsMatch(t, contains, elems) +} + +func FuzzClient(f *testing.F) { + a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2)) + c := newAutoNAT(f, nil) + idAndWait(f, c, a) + + // TODO: Move this to go-multiaddrs + getProto := func(protos []byte) ma.Multiaddr { + protoType := 0 + if len(protos) > 0 { + protoType = int(protos[0]) + } + + port1, port2 := 0, 0 + if len(protos) > 1 { + port1 = int(protos[1]) + } + if len(protos) > 2 { + port2 = int(protos[2]) + } + protoTemplates := []string{ + "/tcp/%d/", + "/udp/%d/", + "/udp/%d/quic-v1/", + "/udp/%d/quic-v1/tcp/%d", + "/udp/%d/quic-v1/webtransport/", + "/udp/%d/webrtc/", + "/udp/%d/webrtc-direct/", + "/unix/hello/", + } + s := protoTemplates[protoType%len(protoTemplates)] + port1 %= (1 << 16) + if strings.Count(s, "%d") == 1 { + return ma.StringCast(fmt.Sprintf(s, port1)) + } + port2 %= (1 << 16) + return ma.StringCast(fmt.Sprintf(s, port1, port2)) + } + + getIP := func(ips []byte) ma.Multiaddr { + ipType := 0 + if len(ips) > 0 { + ipType = int(ips[0]) + } + ips = ips[1:] + var x, y int64 + split := 128 / 8 + if len(ips) < split { + split = len(ips) + } + var b [8]byte + copy(b[:], ips[:split]) + x = int64(binary.LittleEndian.Uint64(b[:])) + clear(b[:]) + copy(b[:], ips[split:]) + y = int64(binary.LittleEndian.Uint64(b[:])) + + var ip netip.Addr + switch ipType % 3 { + case 0: + ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)}) + return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip)) + case 1: + pubIP := net.ParseIP("2005::") // Public IP address + x := int64(binary.LittleEndian.Uint64(pubIP[0:8])) + ip = netip.AddrFrom16([16]byte{ + byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24), + byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56), + byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24), + byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56), + }) + return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip)) + default: + ip := netip.AddrFrom16([16]byte{ + byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24), + byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56), + byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24), + byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56), + }) + 
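// the 2005:: prefix folded into x above keeps this branch in public IPv6 space +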
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip)) + } + } + + getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr { + switch addrType % 4 { + case 0: + return getIP(ips).Encapsulate(getProto(protos)) + case 1: + return getProto(protos) + case 2: + return nil + default: + return getIP(ips).Encapsulate(getProto(protos)) + } + } + + getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr { + hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "") + hostName = strings.ReplaceAll(hostName, "/", "") + if hostName == "" { + hostName = "localhost" + } + dnsType := 0 + if len(hostNameBytes) > 0 { + dnsType = int(hostNameBytes[0]) + } + dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"} + da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName)) + return da.Encapsulate(getProto(protos)) + } + + const maxAddrs = 100 + getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr { + if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 { + return nil + } + numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs + addrs := make([]ma.Multiaddr, numAddrs) + ipIdx := 0 + protoIdx := 0 + for i := range numAddrs { + addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:]) + ipIdx = (ipIdx + 1) % len(ips) + protoIdx = (protoIdx + 1) % len(protos) + } + maxDNSAddrs := 10 + protoIdx = 0 + for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 { + ed := min(i+2, len(hostNames)) + addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:])) + protoIdx = (protoIdx + 1) % len(protos) + } + return addrs + } + // reduce the streamTimeout before running this. TODO: fix this + f.Fuzz(func(_ *testing.T, numAddrs int, ips, protos, hostNames []byte) { + addrs := getAddrs(numAddrs, ips, protos, hostNames) + reqs := make([]Request, len(addrs)) + for i, addr := range addrs { + reqs[i] = Request{Addr: addr, SendDialData: true} + } + c.GetReachability(context.Background(), reqs) + }) +} diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index bbb6145b8c..7cd0dba5f0 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -35,20 +35,20 @@ type normalizeMultiaddrer interface { NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr } -func newClient(h host.Host) *client { - normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a } - if hn, ok := h.(normalizeMultiaddrer); ok { - normalizeMultiaddr = hn.NormalizeMultiaddr - } +func newClient() *client { return &client{ - host: h, - dialData: make([]byte, 4000), - normalizeMultiaddr: normalizeMultiaddr, - dialBackQueues: make(map[uint64]chan ma.Multiaddr), + dialData: make([]byte, 4000), + dialBackQueues: make(map[uint64]chan ma.Multiaddr), } } -func (ac *client) Start() { +func (ac *client) Start(h host.Host) { + normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a } + if hn, ok := h.(normalizeMultiaddrer); ok { + normalizeMultiaddr = hn.NormalizeMultiaddr + } + ac.host = h + ac.normalizeMultiaddr = normalizeMultiaddr ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack) } @@ -109,9 +109,9 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request break // provide dial data if appropriate case msg.GetDialDataRequest() != nil: - if err := ac.validateDialDataRequest(reqs, &msg); err != nil { + if err := validateDialDataRequest(reqs, &msg); err != nil { s.Reset() - return Result{}, fmt.Errorf("invalid dial data request: %w", err) + return Result{}, fmt.Errorf("invalid dial data 
request: %s %w", s.Conn().RemoteMultiaddr(), err) } // dial data request is valid and we want to send data if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil { @@ -136,7 +136,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request // E_DIAL_REFUSED has implication for deciding future address verificiation priorities // wrap a distinct error for convenient errors.Is usage if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED { - return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused) + return Result{AllAddrsRefused: true}, nil } return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(), pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())]) @@ -147,7 +147,6 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request if int(resp.AddrIdx) >= len(reqs) { return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs)) } - // wait for nonce from the server var dialBackAddr ma.Multiaddr if resp.GetDialStatus() == pb.DialStatus_OK { @@ -163,7 +162,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request return ac.newResult(resp, reqs, dialBackAddr) } -func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error { +func validateDialDataRequest(reqs []Request, msg *pb.Message) error { idx := int(msg.GetDialDataRequest().AddrIdx) if idx >= len(reqs) { // invalid address index return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs)) @@ -179,9 +178,13 @@ func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) { idx := int(resp.AddrIdx) + if idx >= len(reqs) { + // This should have been validated by this point, but checking this is cheap. + return Result{}, fmt.Errorf("addrs index(%d) greater than len(reqs)(%d)", idx, len(reqs)) + } addr := reqs[idx].Addr - var rch network.Reachability + rch := network.ReachabilityUnknown //nolint:ineffassign switch resp.DialStatus { case pb.DialStatus_OK: if !ac.areAddrsConsistent(dialBackAddr, addr) { @@ -191,17 +194,16 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr) } rch = network.ReachabilityPublic - case pb.DialStatus_E_DIAL_ERROR: - rch = network.ReachabilityPrivate case pb.DialStatus_E_DIAL_BACK_ERROR: - if ac.areAddrsConsistent(dialBackAddr, addr) { - // We received the dial back but the server claims the dial back errored. - // As long as we received the correct nonce in dial back it is safe to assume - // that we are public. - rch = network.ReachabilityPublic - } else { - rch = network.ReachabilityUnknown + if !ac.areAddrsConsistent(dialBackAddr, addr) { + return Result{}, fmt.Errorf("dial-back stream error: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr) } + // We received the dial back but the server claims the dial back errored. + // As long as we received the correct nonce in dial back it is safe to assume + // that we are public. + rch = network.ReachabilityPublic + case pb.DialStatus_E_DIAL_ERROR: + rch = network.ReachabilityPrivate default: // Unexpected response code. Discard the response and fail. 
log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus) @@ -210,8 +212,8 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr return Result{ Addr: addr, + Idx: idx, Reachability: rch, - Status: resp.DialStatus, }, nil } @@ -307,7 +309,7 @@ func (ac *client) handleDialBack(s network.Stream) { } func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool { - if connLocalAddr == nil || dialedAddr == nil { + if len(connLocalAddr) == 0 || len(dialedAddr) == 0 { return false } connLocalAddr = ac.normalizeMultiaddr(connLocalAddr) @@ -318,32 +320,31 @@ func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) boo if len(localProtos) != len(externalProtos) { return false } - for i := 0; i < len(localProtos); i++ { + for i, lp := range localProtos { + ep := externalProtos[i] if i == 0 { - switch externalProtos[i].Code { + switch ep.Code { case ma.P_DNS, ma.P_DNSADDR: - if localProtos[i].Code == ma.P_IP4 || localProtos[i].Code == ma.P_IP6 { + if lp.Code == ma.P_IP4 || lp.Code == ma.P_IP6 { continue } return false case ma.P_DNS4: - if localProtos[i].Code == ma.P_IP4 { + if lp.Code == ma.P_IP4 { continue } return false case ma.P_DNS6: - if localProtos[i].Code == ma.P_IP6 { + if lp.Code == ma.P_IP6 { continue } return false } - if localProtos[i].Code != externalProtos[i].Code { - return false - } - } else { - if localProtos[i].Code != externalProtos[i].Code { + if lp.Code != ep.Code { return false } + } else if lp.Code != ep.Code { + return false } } return true diff --git a/p2p/protocol/autonatv2/options.go b/p2p/protocol/autonatv2/options.go index 76cd3735f4..f7cf4b7178 100644 --- a/p2p/protocol/autonatv2/options.go +++ b/p2p/protocol/autonatv2/options.go @@ -13,6 +13,7 @@ type autoNATSettings struct { now func() time.Time amplificatonAttackPreventionDialWait time.Duration metricsTracer MetricsTracer + throttlePeerDuration time.Duration } func defaultSettings() *autoNATSettings { @@ -25,6 +26,7 @@ func defaultSettings() *autoNATSettings { dataRequestPolicy: amplificationAttackPrevention, amplificatonAttackPreventionDialWait: 3 * time.Second, now: time.Now, + throttlePeerDuration: defaultThrottlePeerDuration, } } @@ -65,3 +67,10 @@ func withAmplificationAttackPreventionDialWait(d time.Duration) AutoNATOption { return nil } } + +func withThrottlePeerDuration(d time.Duration) AutoNATOption { + return func(s *autoNATSettings) error { + s.throttlePeerDuration = d + return nil + } +} diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index bb10e1e4d7..a8437a2df9 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -59,10 +59,9 @@ type server struct { allowPrivateAddrs bool } -func newServer(host, dialer host.Host, s *autoNATSettings) *server { +func newServer(dialer host.Host, s *autoNATSettings) *server { return &server{ dialerHost: dialer, - host: host, dialDataRequestPolicy: s.dataRequestPolicy, amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait, allowPrivateAddrs: s.allowPrivateAddrs, @@ -79,7 +78,8 @@ func newServer(host, dialer host.Host, s *autoNATSettings) *server { } // Enable attaches the stream handler to the host. 
-func (as *server) Start() { +func (as *server) Start(h host.Host) { + as.host = h as.host.SetStreamHandler(DialProtocol, as.handleDialRequest) } diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index c65aa5b880..7447fc8a4f 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "math" + "strings" "sync" "sync/atomic" "testing" @@ -46,8 +47,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("black holed addr", func(t *testing.T) { @@ -64,8 +65,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { Addr: ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1"), SendDialData: true, }}) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("private addrs", func(t *testing.T) { @@ -76,8 +77,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("relay addrs", func(t *testing.T) { @@ -89,8 +90,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { res, err := c.GetReachability(context.Background(), newTestRequests( []ma.Multiaddr{ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p/%s/p2p-circuit/p2p/%s", c.host.ID(), c.srv.dialerHost.ID()))}, true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("no addr", func(t *testing.T) { @@ -113,8 +114,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) { idAndWait(t, c, an) res, err := c.GetReachability(context.Background(), newTestRequests(addrs, true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + require.NoError(t, err) + require.Equal(t, Result{AllAddrsRefused: true}, res) }) t.Run("msg too large", func(t *testing.T) { @@ -135,7 +136,6 @@ func TestServerInvalidAddrsRejected(t *testing.T) { require.ErrorIs(t, err, network.ErrReset) require.Equal(t, Result{}, res) }) - } func TestServerDataRequest(t *testing.T) { @@ -178,8 +178,8 @@ func TestServerDataRequest(t *testing.T) { require.Equal(t, Result{ Addr: quicAddr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) // Small messages should be rejected for dial data @@ -191,14 +191,11 @@ func TestServerDataRequest(t *testing.T) { func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { const concurrentRequests = 5 - // server will skip all tcp addresses - dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) - - doneChan := make(chan struct{}) - an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( + stallChan := make(chan struct{}) + an := newAutoNAT(t, nil, allowPrivateAddrs, withDataRequestPolicy( // stall all allowed requests func(_, dialAddr ma.Multiaddr) bool { - <-doneChan + <-stallChan return true }), WithServerRateLimit(10, 10, 10, concurrentRequests), @@ -207,16 +204,18 @@ func 
TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { defer an.Close() defer an.host.Close() - c := newAutoNAT(t, nil, allowPrivateAddrs) + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + c := newAutoNAT(t, dialer, allowPrivateAddrs) defer c.Close() defer c.host.Close() idAndWait(t, c, an) errChan := make(chan error) - const N = 10 - // num concurrentRequests will stall and N will fail - for i := 0; i < concurrentRequests+N; i++ { + const n = 10 + // num concurrentRequests will stall and n will fail + for i := 0; i < concurrentRequests+n; i++ { go func() { _, err := c.GetReachability(context.Background(), []Request{{Addr: c.host.Addrs()[0], SendDialData: false}}) errChan <- err @@ -224,17 +223,20 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { } // check N failures - for i := 0; i < N; i++ { + for i := 0; i < n; i++ { select { case err := <-errChan: require.Error(t, err) + if !strings.Contains(err.Error(), "stream reset") && !strings.Contains(err.Error(), "E_REQUEST_REJECTED") { + t.Fatalf("invalid error: %s expected: stream reset or E_REQUEST_REJECTED", err) + } case <-time.After(10 * time.Second): - t.Fatalf("expected %d errors: got: %d", N, i) + t.Fatalf("expected %d errors: got: %d", n, i) } } + close(stallChan) // complete stalled requests // check concurrentRequests failures, as we won't send dial data - close(doneChan) for i := 0; i < concurrentRequests; i++ { select { case err := <-errChan: @@ -290,8 +292,8 @@ func TestServerDataRequestJitter(t *testing.T) { require.Equal(t, Result{ Addr: quicAddr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) if took > 500*time.Millisecond { return @@ -320,8 +322,8 @@ func TestServerDial(t *testing.T) { require.NoError(t, err) require.Equal(t, Result{ Addr: unreachableAddr, + Idx: 0, Reachability: network.ReachabilityPrivate, - Status: pb.DialStatus_E_DIAL_ERROR, }, res) }) @@ -330,16 +332,16 @@ func TestServerDial(t *testing.T) { require.NoError(t, err) require.Equal(t, Result{ Addr: hostAddrs[0], + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) for _, addr := range c.host.Addrs() { res, err := c.GetReachability(context.Background(), newTestRequests([]ma.Multiaddr{addr}, false)) require.NoError(t, err) require.Equal(t, Result{ Addr: addr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) } }) @@ -347,12 +349,8 @@ func TestServerDial(t *testing.T) { t.Run("dialback error", func(t *testing.T) { c.host.RemoveStreamHandler(DialBackProtocol) res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) - require.NoError(t, err) - require.Equal(t, Result{ - Addr: hostAddrs[0], - Reachability: network.ReachabilityUnknown, - Status: pb.DialStatus_E_DIAL_BACK_ERROR, - }, res) + require.ErrorContains(t, err, "dial-back stream error") + require.Equal(t, Result{}, res) }) } @@ -396,7 +394,6 @@ func TestRateLimiter(t *testing.T) { cl.AdvanceBy(10 * time.Second) require.True(t, r.Accept("peer3")) - } func TestRateLimiterConcurrentRequests(t *testing.T) { @@ -558,22 +555,23 @@ func TestServerDataRequestWithAmplificationAttackPrevention(t *testing.T) { require.NoError(t, err) require.Equal(t, Result{ Addr: quicv4Addr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) // ipv6 address should require dial data _, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: 
false}}) require.Error(t, err) - require.ErrorContains(t, err, "invalid dial data request: low priority addr") + require.ErrorContains(t, err, "invalid dial data request") + require.ErrorContains(t, err, "low priority addr") // ipv6 address should work fine with dial data res, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: true}}) require.NoError(t, err) require.Equal(t, Result{ Addr: quicv6Addr, + Idx: 0, Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, }, res) }
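
A minimal usage sketch of the reworked client API, assuming a started *autonatv2.AutoNAT; the example package, the helper name checkReachability, and its arguments are hypothetical, while the Request and Result fields and the ErrNoPeers/ErrPrivateAddrs errors are the ones introduced above:

package example // hypothetical package for the sketch

import (
	"context"
	"errors"

	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
	ma "github.com/multiformats/go-multiaddr"
)

// checkReachability is a hypothetical helper. an is assumed to be a started
// *autonatv2.AutoNAT and addrs the host's public listen addresses.
func checkReachability(ctx context.Context, an *autonatv2.AutoNAT, addrs []ma.Multiaddr) (ma.Multiaddr, network.Reachability, bool) {
	reqs := make([]autonatv2.Request, 0, len(addrs))
	for _, a := range addrs {
		reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true})
	}
	res, err := an.GetReachability(ctx, reqs)
	switch {
	case errors.Is(err, autonatv2.ErrNoPeers):
		// no autonatv2 servers known yet; retry once more peers are identified
		return nil, network.ReachabilityUnknown, false
	case errors.Is(err, autonatv2.ErrPrivateAddrs):
		// every requested address was private; nothing can be verified
		return nil, network.ReachabilityUnknown, false
	case err != nil:
		// transient failure; callers such as addrsReachabilityTracker back off here
		return nil, network.ReachabilityUnknown, false
	case res.AllAddrsRefused:
		// the server declined to dial any of the requested addresses
		return nil, network.ReachabilityUnknown, false
	default:
		// res.Idx indexes the original reqs slice; res.Reachability is the verdict for res.Addr
		return res.Addr, res.Reachability, true
	}
}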