diff --git a/dial_queue.go b/dial_queue.go
new file mode 100644
index 000000000..f4a4fec97
--- /dev/null
+++ b/dial_queue.go
@@ -0,0 +1,293 @@
+package dht
+
+import (
+	"context"
+	"math"
+	"time"
+
+	peer "github.com/libp2p/go-libp2p-peer"
+	queue "github.com/libp2p/go-libp2p-peerstore/queue"
+)
+
+var (
+	// DialQueueMinParallelism is the minimum number of worker dial goroutines that will be alive at any time.
+	DialQueueMinParallelism = 6
+	// DialQueueMaxParallelism is the maximum number of worker dial goroutines that can be alive at any time.
+	DialQueueMaxParallelism = 20
+	// DialQueueMaxIdle is the period that a worker dial goroutine waits before signalling a worker pool downscaling.
+	DialQueueMaxIdle = 5 * time.Second
+	// DialQueueScalingMutePeriod is the amount of time to ignore further worker pool scaling events, after one is
+	// processed. Its role is to reduce jitter.
+	DialQueueScalingMutePeriod = 1 * time.Second
+)
+
+type dialQueue struct {
+	ctx    context.Context
+	dialFn func(context.Context, peer.ID) error
+
+	nWorkers      int
+	scalingFactor float64
+
+	in  *queue.ChanQueue
+	out *queue.ChanQueue
+
+	waitingCh chan waitingCh
+	dieCh     chan struct{}
+	growCh    chan struct{}
+	shrinkCh  chan struct{}
+}
+
+type waitingCh struct {
+	ch chan<- peer.ID
+	ts time.Time
+}
+
+// newDialQueue returns an adaptive dial queue that spawns a dynamically sized set of goroutines to preemptively
+// stage dials for later handoff to the DHT protocol for RPC. It identifies backpressure on both ends (dial consumers
+// and dial producers), and takes compensating action by adjusting the worker pool.
+//
+// Why? Dialing is expensive. It's orders of magnitude slower than running an RPC on an already-established
+// connection, as it requires establishing a TCP connection, multistream handshake, crypto handshake, mux handshake,
+// and protocol negotiation.
+//
+// We start with DialQueueMinParallelism workers, and scale up and down based on demand and supply of
+// dialled peers.
+//
+// The following events trigger scaling:
+// - we scale up when we can't immediately return a successful dial to a new consumer.
+// - we scale down when we've been idle for a while waiting for new dial attempts.
+// - we scale down when we complete a dial and realise nobody was waiting for it.
+//
+// Dialler throttling (e.g. FD limit exceeded) is a concern, as we can easily spin up more workers to compensate, and
+// end up adding fuel to the fire. Since we have no deterministic way to detect this for now, we hard-limit concurrency
+// to DialQueueMaxParallelism.
+func newDialQueue(ctx context.Context, target string, in *queue.ChanQueue, dialFn func(context.Context, peer.ID) error) *dialQueue {
+	sq := &dialQueue{
+		ctx:           ctx,
+		dialFn:        dialFn,
+		nWorkers:      DialQueueMinParallelism,
+		scalingFactor: 1.5,
+
+		in:  in,
+		out: queue.NewChanQueue(ctx, queue.NewXORDistancePQ(target)),
+
+		growCh:    make(chan struct{}, 1),
+		shrinkCh:  make(chan struct{}, 1),
+		waitingCh: make(chan waitingCh),
+		dieCh:     make(chan struct{}, DialQueueMaxParallelism),
+	}
+	for i := 0; i < DialQueueMinParallelism; i++ {
+		go sq.worker()
+	}
+	go sq.control()
+	return sq
+}
+
+func (dq *dialQueue) control() {
+	var (
+		dialled        <-chan peer.ID
+		waiting        []waitingCh
+		lastScalingEvt = time.Now()
+	)
+
+	defer func() {
+		for _, w := range waiting {
+			close(w.ch)
+		}
+		waiting = nil
+	}()
+
+	for {
+		// First process any backlog of dial jobs and waiters -- making progress is the priority.
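+		// The drain happens in a non-blocking select (note the default case), so delivering dialled peers to
+		// waiters always takes precedence over reacting to grow/shrink signals in the main select further down.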
+		// This block is copied below; couldn't find a more concise way of doing this.
+		select {
+		case <-dq.ctx.Done():
+			return
+		case w := <-dq.waitingCh:
+			waiting = append(waiting, w)
+			dialled = dq.out.DeqChan
+			continue // onto the top.
+		case p, ok := <-dialled:
+			if !ok {
+				return // we're done if the ChanQueue is closed, which happens when the context is closed.
+			}
+			w := waiting[0]
+			log.Debugf("delivering dialled peer to DHT; took %dms.", time.Since(w.ts)/time.Millisecond)
+			w.ch <- p
+			close(w.ch)
+			waiting = waiting[1:]
+			if len(waiting) == 0 {
+				// no more waiters, so stop consuming dialled jobs.
+				dialled = nil
+			}
+			continue // onto the top.
+		default:
+			// there's nothing to process, so proceed onto the main select block.
+		}
+
+		select {
+		case <-dq.ctx.Done():
+			return
+		case w := <-dq.waitingCh:
+			waiting = append(waiting, w)
+			dialled = dq.out.DeqChan
+		case p, ok := <-dialled:
+			if !ok {
+				return // we're done if the ChanQueue is closed, which happens when the context is closed.
+			}
+			w := waiting[0]
+			log.Debugf("delivering dialled peer to DHT; took %dms.", time.Since(w.ts)/time.Millisecond)
+			w.ch <- p
+			close(w.ch)
+			waiting = waiting[1:]
+			if len(waiting) == 0 {
+				// no more waiters, so stop consuming dialled jobs.
+				dialled = nil
+			}
+		case <-dq.growCh:
+			if time.Since(lastScalingEvt) < DialQueueScalingMutePeriod {
+				continue
+			}
+			dq.grow()
+			lastScalingEvt = time.Now()
+		case <-dq.shrinkCh:
+			if time.Since(lastScalingEvt) < DialQueueScalingMutePeriod {
+				continue
+			}
+			dq.shrink()
+			lastScalingEvt = time.Now()
+		}
+	}
+}
+
+func (dq *dialQueue) Consume() <-chan peer.ID {
+	ch := make(chan peer.ID, 1)
+
+	select {
+	case p := <-dq.out.DeqChan:
+		// short circuit and return a dialled peer if it's immediately available.
+		ch <- p
+		close(ch)
+		return ch
+	case <-dq.ctx.Done():
+		// return a closed channel with no value if we're done.
+		close(ch)
+		return ch
+	default:
+	}
+
+	// we have no finished dials to return, so trigger a scale up.
+	select {
+	case dq.growCh <- struct{}{}:
+	default:
+	}
+
+	// park the channel until a dialled peer becomes available.
+	select {
+	case dq.waitingCh <- waitingCh{ch, time.Now()}:
+		// all good
+	case <-dq.ctx.Done():
+		// return a closed channel with no value if we're done.
+		close(ch)
+	}
+	return ch
+}
+
+func (dq *dialQueue) grow() {
+	// no mutex needed as this is only called from the (single-threaded) control loop.
+	defer func(prev int) {
+		if prev == dq.nWorkers {
+			return
+		}
+		log.Debugf("grew dial worker pool: %d => %d", prev, dq.nWorkers)
+	}(dq.nWorkers)
+
+	if dq.nWorkers == DialQueueMaxParallelism {
+		return
+	}
+	target := int(math.Floor(float64(dq.nWorkers) * dq.scalingFactor))
+	if target > DialQueueMaxParallelism {
+		target = DialQueueMaxParallelism
+	}
+	for ; dq.nWorkers < target; dq.nWorkers++ {
+		go dq.worker()
+	}
+}
+
+func (dq *dialQueue) shrink() {
+	// no mutex needed as this is only called from the (single-threaded) control loop.
+	defer func(prev int) {
+		if prev == dq.nWorkers {
+			return
+		}
+		log.Debugf("shrunk dial worker pool: %d => %d", prev, dq.nWorkers)
+	}(dq.nWorkers)
+
+	if dq.nWorkers == DialQueueMinParallelism {
+		return
+	}
+	target := int(math.Floor(float64(dq.nWorkers) / dq.scalingFactor))
+	if target < DialQueueMinParallelism {
+		target = DialQueueMinParallelism
+	}
+	// send as many die signals as workers we have to prune.
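+	// dieCh is buffered to DialQueueMaxParallelism, so these sends normally complete without blocking; the
+	// default branch below only fires if that buffer is already full of unconsumed die signals.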
+	for ; dq.nWorkers > target; dq.nWorkers-- {
+		select {
+		case dq.dieCh <- struct{}{}:
+		default:
+			log.Debugf("too many die signals queued up.")
+		}
+	}
+}
+
+func (dq *dialQueue) worker() {
+	// This idle timer tracks if the environment is slow. If we're waiting too long to acquire a peer to dial,
+	// it means that the DHT query is progressing slowly and we should shrink the worker pool.
+	idleTimer := time.NewTimer(24 * time.Hour) // placeholder init value which will be overridden immediately.
+	for {
+		// trap exit signals first.
+		select {
+		case <-dq.ctx.Done():
+			return
+		case <-dq.dieCh:
+			return
+		default:
+		}
+
+		idleTimer.Stop()
+		select {
+		case <-idleTimer.C:
+		default:
+		}
+		idleTimer.Reset(DialQueueMaxIdle)
+
+		select {
+		case <-dq.dieCh:
+			return
+		case <-dq.ctx.Done():
+			return
+		case <-idleTimer.C:
+			// no new dial requests during our idle period; time to scale down.
+		case p := <-dq.in.DeqChan:
+			t := time.Now()
+			if err := dq.dialFn(dq.ctx, p); err != nil {
+				log.Debugf("discarding dialled peer because of error: %v", err)
+				continue
+			}
+			log.Debugf("dialling %v took %dms (as observed by the dht subsystem).", p, time.Since(t)/time.Millisecond)
+			waiting := len(dq.waitingCh)
+			dq.out.EnqChan <- p
+			if waiting > 0 {
+				// we have somebody to deliver this value to, so no need to shrink.
+				continue
+			}
+		}
+
+		// scaling down; control only arrives here if the idle timer fires, or if there are no goroutines
+		// waiting for the value we just produced.
+		select {
+		case dq.shrinkCh <- struct{}{}:
+		default:
+		}
+	}
+}
diff --git a/dial_queue_test.go b/dial_queue_test.go
new file mode 100644
index 000000000..04ead9670
--- /dev/null
+++ b/dial_queue_test.go
@@ -0,0 +1,209 @@
+package dht
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	peer "github.com/libp2p/go-libp2p-peer"
+	queue "github.com/libp2p/go-libp2p-peerstore/queue"
+)
+
+func init() {
+	DialQueueScalingMutePeriod = 0
+}
+
+func TestDialQueueGrowsOnSlowDials(t *testing.T) {
+	DialQueueMaxIdle = 10 * time.Minute
+
+	in := queue.NewChanQueue(context.Background(), queue.NewXORDistancePQ("test"))
+	hang := make(chan struct{})
+
+	var cnt int32
+	dialFn := func(ctx context.Context, p peer.ID) error {
+		atomic.AddInt32(&cnt, 1)
+		<-hang
+		return nil
+	}
+
+	// Enqueue 20 jobs.
+	for i := 0; i < 20; i++ {
+		in.EnqChan <- peer.ID(i)
+	}
+
+	// the scaling mute period is zeroed in init(), so the pool can grow faster.
+	dq := newDialQueue(context.Background(), "test", in, dialFn)
+
+	for i := 0; i < 4; i++ {
+		_ = dq.Consume()
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	for i := 0; i < 20; i++ {
+		if atomic.LoadInt32(&cnt) > int32(DialQueueMinParallelism) {
+			return
+		}
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	t.Errorf("expected more than %d concurrent dials, got %d", DialQueueMinParallelism, atomic.LoadInt32(&cnt))
+}
+
+func TestDialQueueShrinksWithNoConsumers(t *testing.T) {
+	// reduce interference from the other shrink path.
+	DialQueueMaxIdle = 10 * time.Minute
+
+	in := queue.NewChanQueue(context.Background(), queue.NewXORDistancePQ("test"))
+	hang := make(chan struct{})
+
+	wg := new(sync.WaitGroup)
+	wg.Add(13)
+	dialFn := func(ctx context.Context, p peer.ID) error {
+		wg.Done()
+		<-hang
+		return nil
+	}
+
+	dq := newDialQueue(context.Background(), "test", in, dialFn)
+
+	defer func() {
+		recover()
+		fmt.Println(dq.nWorkers)
+	}()
+
+	// acquire 3 consumers; every time we acquire a consumer we grow the pool, because no dial job has completed
+	// and is immediately returnable.
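+	// Each Consume() call below parks a waiter and fires a grow signal at the control loop: nothing has been
+	// dialled yet, so the queue has no finished dial to hand back immediately.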
+	for i := 0; i < 3; i++ {
+		_ = dq.Consume()
+	}
+
+	// Enqueue 13 jobs, one per worker we'll grow to.
+	for i := 0; i < 13; i++ {
+		in.EnqChan <- peer.ID(i)
+	}
+
+	waitForWg(t, wg, 2*time.Second)
+
+	// Release a few dialFn, but not all of them because downscaling happens when workers detect there are no
+	// consumers to consume their values. So the other three will be these witnesses.
+	for i := 0; i < 3; i++ {
+		hang <- struct{}{}
+	}
+
+	// allow enough time for signalling and dispatching values to outstanding consumers.
+	time.Sleep(1 * time.Second)
+
+	// unblock the rest.
+	for i := 0; i < 10; i++ {
+		hang <- struct{}{}
+	}
+
+	wg = new(sync.WaitGroup)
+	// we should now only have 6 workers, because all the shrink events will have been honoured.
+	wg.Add(6)
+
+	// enqueue more jobs.
+	for i := 0; i < 6; i++ {
+		in.EnqChan <- peer.ID(i)
+	}
+
+	// let's check we have 6 workers hanging.
+	waitForWg(t, wg, 2*time.Second)
+}
+
+// Inactivity = workers are idle because the DHT query is progressing slowly and is producing too few peers to dial.
+func TestDialQueueShrinksWithWhenIdle(t *testing.T) {
+	DialQueueMaxIdle = 1 * time.Second
+
+	in := queue.NewChanQueue(context.Background(), queue.NewXORDistancePQ("test"))
+	hang := make(chan struct{})
+
+	var wg sync.WaitGroup
+	wg.Add(13)
+	dialFn := func(ctx context.Context, p peer.ID) error {
+		wg.Done()
+		<-hang
+		return nil
+	}
+
+	// Enqueue 13 jobs.
+	for i := 0; i < 13; i++ {
+		in.EnqChan <- peer.ID(i)
+	}
+
+	dq := newDialQueue(context.Background(), "test", in, dialFn)
+
+	// keep up to speed with backlog by releasing the dial function every time we acquire a channel.
+	for i := 0; i < 13; i++ {
+		ch := dq.Consume()
+		hang <- struct{}{}
+		<-ch
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	// wait for DialQueueMaxIdle to elapse.
+	time.Sleep(1500 * time.Millisecond)
+
+	// we should now only have 6 workers, because all the shrink events will have been honoured.
+	wg.Add(6)
+
+	// enqueue more jobs.
+	for i := 0; i < 10; i++ {
+		in.EnqChan <- peer.ID(i)
+	}
+
+	// let's check we have 6 workers hanging.
+	waitForWg(t, &wg, 2*time.Second)
+}
+
+func TestDialQueueMutePeriodHonored(t *testing.T) {
+	DialQueueScalingMutePeriod = 2 * time.Second
+
+	in := queue.NewChanQueue(context.Background(), queue.NewXORDistancePQ("test"))
+	hang := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(6)
+	dialFn := func(ctx context.Context, p peer.ID) error {
+		wg.Done()
+		<-hang
+		return nil
+	}
+
+	// Enqueue a bunch of jobs.
+	for i := 0; i < 20; i++ {
+		in.EnqChan <- peer.ID(i)
+	}
+
+	dq := newDialQueue(context.Background(), "test", in, dialFn)
+
+	// pick up three consumers.
+	for i := 0; i < 3; i++ {
+		_ = dq.Consume()
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	time.Sleep(500 * time.Millisecond)
+
+	// we'll only have 6 workers because the grow signals have been ignored.
+	waitForWg(t, &wg, 2*time.Second)
+}
+
+func waitForWg(t *testing.T, wg *sync.WaitGroup, wait time.Duration) {
+	t.Helper()
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		wg.Wait()
+	}()
+
+	select {
+	case <-time.After(wait):
+		t.Error("timeout while waiting for WaitGroup")
+	case <-done:
+	}
+}
diff --git a/query.go b/query.go
index 8794deaa9..9cc7076b8 100644
--- a/query.go
+++ b/query.go
@@ -74,6 +74,7 @@ type dhtQueryRunner struct {
 	query          *dhtQuery        // query to run
 	peersSeen      *pset.PeerSet    // all peers queried. prevent querying same peer 2x
 	peersQueried   *pset.PeerSet    // peers successfully connected to and queried
+	peersDialed    *dialQueue       // peers we have dialed to
 	peersToQuery   *queue.ChanQueue // peers remaining to be queried
 	peersRemaining todoctr.Counter  // peersToQuery + currently processing
 
@@ -92,15 +93,18 @@ type dhtQueryRunner struct {
 func newQueryRunner(q *dhtQuery) *dhtQueryRunner {
 	proc := process.WithParent(process.Background())
 	ctx := ctxproc.OnClosingContext(proc)
-	return &dhtQueryRunner{
+	peersToQuery := queue.NewChanQueue(ctx, queue.NewXORDistancePQ(string(q.key)))
+	r := &dhtQueryRunner{
 		query:          q,
-		peersToQuery:   queue.NewChanQueue(ctx, queue.NewXORDistancePQ(string(q.key))),
 		peersRemaining: todoctr.NewSyncCounter(),
 		peersSeen:      pset.New(),
 		peersQueried:   pset.New(),
 		rateLimit:      make(chan struct{}, q.concurrency),
+		peersToQuery:   peersToQuery,
 		proc:           proc,
 	}
+	r.peersDialed = newDialQueue(ctx, q.key, peersToQuery, r.dialPeer)
+	return r
 }
 
 func (r *dhtQueryRunner) Run(ctx context.Context, peers []peer.ID) (*dhtQueryResult, error) {
@@ -192,7 +196,6 @@ func (r *dhtQueryRunner) addPeerToQuery(next peer.ID) {
 
 func (r *dhtQueryRunner) spawnWorkers(proc process.Process) {
 	for {
-
 		select {
 		case <-r.peersRemaining.Done():
 			return
@@ -201,14 +204,13 @@ func (r *dhtQueryRunner) spawnWorkers(proc process.Process) {
 			return
 
 		case <-r.rateLimit:
+			ch := r.peersDialed.Consume()
 			select {
-			case p, more := <-r.peersToQuery.DeqChan:
-				if !more {
-					// Put this back so we can finish any outstanding queries.
-					r.rateLimit <- struct{}{}
-					return // channel closed.
+			case p, ok := <-ch:
+				if !ok {
+					// this signals context cancellation.
+					return
 				}
-
 				// do it as a child func to make sure Run exits
 				// ONLY AFTER spawn workers has exited.
 				proc.Go(func(proc process.Process) {
@@ -223,6 +225,36 @@ func (r *dhtQueryRunner) spawnWorkers(proc process.Process) {
 	}
 }
 
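+// dialPeer is the dial function handed to the dial queue in newQueryRunner. It runs on the queue's worker
+// goroutines, which stage connections ahead of demand, replacing the inline dialing that queryPeer used to do
+// in the block removed below.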
+func (r *dhtQueryRunner) dialPeer(ctx context.Context, p peer.ID) error {
+	// short-circuit if we're already connected.
+	if r.query.dht.host.Network().Connectedness(p) == inet.Connected {
+		return nil
+	}
+
+	log.Debug("not connected. dialing.")
+	notif.PublishQueryEvent(r.runCtx, &notif.QueryEvent{
+		Type: notif.DialingPeer,
+		ID:   p,
+	})
+
+	pi := pstore.PeerInfo{ID: p}
+	if err := r.query.dht.host.Connect(ctx, pi); err != nil {
+		log.Debugf("error connecting: %s", err)
+		notif.PublishQueryEvent(r.runCtx, &notif.QueryEvent{
+			Type:  notif.QueryError,
+			Extra: err.Error(),
+			ID:    p,
+		})
+
+		r.Lock()
+		r.errs = append(r.errs, err)
+		r.Unlock()
+		return err
+	}
+	log.Debugf("connected. dial success.")
+	return nil
+}
+
 func (r *dhtQueryRunner) queryPeer(proc process.Process, p peer.ID) {
 	// ok let's do this!
 
@@ -236,42 +268,6 @@ func (r *dhtQueryRunner) queryPeer(proc process.Process, p peer.ID) {
 		r.rateLimit <- struct{}{}
 	}()
 
-	// make sure we're connected to the peer.
-	// FIXME abstract away into the network layer
-	// Note: Failure to connect in this block will cause the function to
-	// short circuit.
-	if r.query.dht.host.Network().Connectedness(p) == inet.NotConnected {
-		log.Debug("not connected. dialing.")
-
-		notif.PublishQueryEvent(r.runCtx, &notif.QueryEvent{
-			Type: notif.DialingPeer,
-			ID:   p,
-		})
-		// while we dial, we do not take up a rate limit. this is to allow
-		// forward progress during potentially very high latency dials.
-		r.rateLimit <- struct{}{}
-
-		pi := pstore.PeerInfo{ID: p}
-
-		if err := r.query.dht.host.Connect(ctx, pi); err != nil {
-			log.Debugf("Error connecting: %s", err)
-
-			notif.PublishQueryEvent(r.runCtx, &notif.QueryEvent{
-				Type:  notif.QueryError,
-				Extra: err.Error(),
-				ID:    p,
-			})
-
-			r.Lock()
-			r.errs = append(r.errs, err)
-			r.Unlock()
-			<-r.rateLimit // need to grab it again, as we deferred.
-			return
-		}
-		<-r.rateLimit // need to grab it again, as we deferred.
-		log.Debugf("connected. dial success.")
-	}
-
 	// finally, run the query against this peer
 	res, err := r.query.qfunc(ctx, p)