diff --git a/p2p/enode/iter.go b/p2p/enode/iter.go index 0f3f4178ee..6aa17e4c09 100644 --- a/p2p/enode/iter.go +++ b/p2p/enode/iter.go @@ -180,7 +180,7 @@ type AsyncFilterFunc func(context.Context, *Node) *Node func AsyncFilter(it Iterator, check AsyncFilterFunc, workers int) Iterator { f := &asyncFilterIter{ it: ensureSourceIter(it), - slots: make(chan struct{}, workers+1), + slots: make(chan struct{}, workers+1), // extra 1 slot to make sure all the goroutines can be completed passed: make(chan iteratorItem), } for range cap(f.slots) { @@ -195,6 +195,9 @@ func AsyncFilter(it Iterator, check AsyncFilterFunc, workers int) Iterator { return case <-f.slots: } + defer func() { + f.slots <- struct{}{} // the iterator has ended + }() // read from the iterator and start checking nodes in parallel // when a node is checked, it will be sent to the passed channel // and the slot will be released @@ -203,7 +206,11 @@ func AsyncFilter(it Iterator, check AsyncFilterFunc, workers int) Iterator { nodeSource := f.it.NodeSource() // check the node async, in a separate goroutine - <-f.slots + select { + case <-ctx.Done(): + return + case <-f.slots: + } go func() { if nn := check(ctx, node); nn != nil { item := iteratorItem{nn, nodeSource} @@ -215,8 +222,6 @@ func AsyncFilter(it Iterator, check AsyncFilterFunc, workers int) Iterator { f.slots <- struct{}{} }() } - // the iterator has ended - f.slots <- struct{}{} }() return f