diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go
index 1e56687a6c78..28b0a05a2183 100644
--- a/cmd/devp2p/discv4cmd.go
+++ b/cmd/devp2p/discv4cmd.go
@@ -38,6 +38,7 @@ var (
discv4PingCommand,
discv4RequestRecordCommand,
discv4ResolveCommand,
+ discv4RandomWalkCommand,
},
}
discv4PingCommand = cli.Command{
@@ -56,6 +57,12 @@ var (
Action: discv4Resolve,
Flags: []cli.Flag{bootnodesFlag},
}
+ discv4RandomWalkCommand = cli.Command{
+ Name: "randomwalk",
+ Usage: "Prints random nodes found in the DHT",
+ Action: discv4RandomNodes,
+ Flags: []cli.Flag{bootnodesFlag},
+ }
)
var bootnodesFlag = cli.StringFlag{
@@ -104,6 +111,24 @@ func discv4Resolve(ctx *cli.Context) error {
return nil
}
+func discv4RandomNodes(ctx *cli.Context) error {
+ bootnodes, err := parseBootnodes(ctx)
+ if err != nil {
+ return err
+ }
+ disc, err := startV4(bootnodes)
+ if err != nil {
+ return err
+ }
+ defer disc.Close()
+
+ it := disc.RandomNodes(nil)
+ for it.Next() {
+ fmt.Println(it.Node())
+ }
+ return nil
+}
+
func getNodeArgAndStartV4(ctx *cli.Context) (*enode.Node, *discover.UDPv4, error) {
if ctx.NArg() != 1 {
return nil, nil, fmt.Errorf("missing node as command-line argument")
diff --git a/p2p/dial.go b/p2p/dial.go
index 8dee5063f1d5..b2e20ba338f5 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -23,6 +23,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/discutil"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/netutil"
)
@@ -33,12 +34,10 @@ const (
// private networks.
dialHistoryExpiration = inboundThrottleTime + 5*time.Second
- // Discovery lookups are throttled and can only run
- // once every few seconds.
- lookupInterval = 4 * time.Second
+ // Timeout for NextNode on the discovery iterator.
+ discoveryTimeout = 4 * time.Second
- // If no peers are found for this amount of time, the initial bootnodes are
- // attempted to be connected.
+ // If no peers are found for this amount of time, the initial bootnodes are dialed.
fallbackInterval = 20 * time.Second
// Endpoint resolution is throttled with bounded backoff.
@@ -52,6 +51,10 @@ type NodeDialer interface {
Dial(*enode.Node) (net.Conn, error)
}
+type nodeResolver interface {
+ Resolve(*enode.Node) *enode.Node
+}
+
// TCPDialer implements the NodeDialer interface by using a net.Dialer to
// create TCP connections to nodes in the network
type TCPDialer struct {
@@ -69,7 +72,6 @@ func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) {
// of the main loop in Server.run.
type dialstate struct {
maxDynDials int
- ntab discoverTable
netrestrict *netutil.Netlist
self enode.ID
bootnodes []*enode.Node // default dials when there are no peers
@@ -79,55 +81,23 @@ type dialstate struct {
lookupRunning bool
dialing map[enode.ID]connFlag
lookupBuf []*enode.Node // current discovery lookup results
- randomNodes []*enode.Node // filled from Table
static map[enode.ID]*dialTask
hist expHeap
}
-type discoverTable interface {
- Close()
- Resolve(*enode.Node) *enode.Node
- LookupRandom() []*enode.Node
- ReadRandomNodes([]*enode.Node) int
-}
-
type task interface {
Do(*Server)
}
-// A dialTask is generated for each node that is dialed. Its
-// fields cannot be accessed while the task is running.
-type dialTask struct {
- flags connFlag
- dest *enode.Node
- lastResolved time.Time
- resolveDelay time.Duration
-}
-
-// discoverTask runs discovery table operations.
-// Only one discoverTask is active at any time.
-// discoverTask.Do performs a random lookup.
-type discoverTask struct {
- results []*enode.Node
-}
-
-// A waitExpireTask is generated if there are no other tasks
-// to keep the loop in Server.run ticking.
-type waitExpireTask struct {
- time.Duration
-}
-
-func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
+func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
- ntab: ntab,
self: self,
netrestrict: cfg.NetRestrict,
log: cfg.Logger,
static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag),
bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
- randomNodes: make([]*enode.Node, maxdyn/2),
}
copy(s.bootnodes, cfg.BootstrapNodes)
if s.log == nil {
@@ -151,10 +121,6 @@ func (s *dialstate) removeStatic(n *enode.Node) {
}
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
- if s.start.IsZero() {
- s.start = now
- }
-
var newtasks []task
addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil {
@@ -166,20 +132,9 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
return true
}
- // Compute number of dynamic dials necessary at this point.
- needDynDials := s.maxDynDials
- for _, p := range peers {
- if p.rw.is(dynDialedConn) {
- needDynDials--
- }
- }
- for _, flag := range s.dialing {
- if flag&dynDialedConn != 0 {
- needDynDials--
- }
+ if s.start.IsZero() {
+ s.start = now
}
-
- // Expire the dial history on every invocation.
s.hist.expire(now)
// Create dials for static nodes if they are not connected.
@@ -194,6 +149,20 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
newtasks = append(newtasks, t)
}
}
+
+ // Compute number of dynamic dials needed.
+ needDynDials := s.maxDynDials
+ for _, p := range peers {
+ if p.rw.is(dynDialedConn) {
+ needDynDials--
+ }
+ }
+ for _, flag := range s.dialing {
+ if flag&dynDialedConn != 0 {
+ needDynDials--
+ }
+ }
+
// If we don't have any peers whatsoever, try to dial a random bootnode. This
// scenario is useful for the testnet (and private networks) where the discovery
// table might be full of mostly bad peers, making it hard to find good ones.
@@ -201,24 +170,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
bootnode := s.bootnodes[0]
s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...)
s.bootnodes = append(s.bootnodes, bootnode)
-
if addDial(dynDialedConn, bootnode) {
needDynDials--
}
}
- // Use random nodes from the table for half of the necessary
- // dynamic dials.
- randomCandidates := needDynDials / 2
- if randomCandidates > 0 {
- n := s.ntab.ReadRandomNodes(s.randomNodes)
- for i := 0; i < randomCandidates && i < n; i++ {
- if addDial(dynDialedConn, s.randomNodes[i]) {
- needDynDials--
- }
- }
- }
- // Create dynamic dials from random lookup results, removing tried
- // items from the result buffer.
+
+ // Create dynamic dials from discovery results.
i := 0
for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
if addDial(dynDialedConn, s.lookupBuf[i]) {
@@ -226,10 +183,11 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
}
}
s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
+
// Launch a discovery lookup if more candidates are needed.
if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
s.lookupRunning = true
- newtasks = append(newtasks, &discoverTask{})
+ newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)})
}
// Launch a timer to wait for the next node to expire if all
@@ -279,6 +237,15 @@ func (s *dialstate) taskDone(t task, now time.Time) {
}
}
+// A dialTask is generated for each node that is dialed. Its
+// fields cannot be accessed while the task is running.
+type dialTask struct {
+ flags connFlag
+ dest *enode.Node
+ lastResolved time.Time
+ resolveDelay time.Duration
+}
+
func (t *dialTask) Do(srv *Server) {
if t.dest.Incomplete() {
if !t.resolve(srv) {
@@ -304,7 +271,7 @@ func (t *dialTask) Do(srv *Server) {
// discovery network with useless queries for nodes that don't exist.
// The backoff delay resets when the node is found.
func (t *dialTask) resolve(srv *Server) bool {
- if srv.ntab == nil {
+ if srv.staticNodeResolver == nil {
srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
return false
}
@@ -314,7 +281,7 @@ func (t *dialTask) resolve(srv *Server) bool {
if time.Since(t.lastResolved) < t.resolveDelay {
return false
}
- resolved := srv.ntab.Resolve(t.dest)
+ resolved := srv.staticNodeResolver.Resolve(t.dest)
t.lastResolved = time.Now()
if resolved == nil {
t.resolveDelay *= 2
@@ -350,26 +317,34 @@ func (t *dialTask) String() string {
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
}
+// discoverTask runs discovery table operations.
+// Only one discoverTask is active at any time.
+// discoverTask.Do performs a random lookup.
+type discoverTask struct {
+ want int
+ results []*enode.Node
+}
+
func (t *discoverTask) Do(srv *Server) {
- // newTasks generates a lookup task whenever dynamic dials are
- // necessary. Lookups need to take some time, otherwise the
- // event loop spins too fast.
- next := srv.lastLookup.Add(lookupInterval)
- if now := time.Now(); now.Before(next) {
- time.Sleep(next.Sub(now))
- }
- srv.lastLookup = time.Now()
- t.results = srv.ntab.LookupRandom()
+ t.results = discutil.ReadNodes(srv.discmix, t.want)
}
func (t *discoverTask) String() string {
- s := "discovery lookup"
+ s := "discovery query"
if len(t.results) > 0 {
s += fmt.Sprintf(" (%d results)", len(t.results))
+ } else {
+ s += fmt.Sprintf(" (want %d)", t.want)
}
return s
}
+// A waitExpireTask is generated if there are no other tasks
+// to keep the loop in Server.run ticking.
+type waitExpireTask struct {
+ time.Duration
+}
+
func (t waitExpireTask) Do(*Server) {
time.Sleep(t.Duration)
}
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index de8fc4a6e3e6..6189ec4d0b85 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -73,7 +73,7 @@ func runDialTest(t *testing.T, test dialtest) {
t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
}
- t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new)))
+ t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new)))
// Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second)
@@ -81,19 +81,11 @@ func runDialTest(t *testing.T, test dialtest) {
}
}
-type fakeTable []*enode.Node
-
-func (t fakeTable) Self() *enode.Node { return new(enode.Node) }
-func (t fakeTable) Close() {}
-func (t fakeTable) LookupRandom() []*enode.Node { return nil }
-func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil }
-func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) }
-
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, fakeTable{}, 5, config),
+ init: newDialState(enode.ID{}, 5, config),
rounds: []round{
// A discovery query is launched.
{
@@ -102,7 +94,9 @@ func TestDialStateDynDial(t *testing.T) {
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
- new: []task{&discoverTask{}},
+ new: []task{
+ &discoverTask{want: 3},
+ },
},
// Dynamic dials are launched when it completes.
{
@@ -188,7 +182,7 @@ func TestDialStateDynDial(t *testing.T) {
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)},
- &discoverTask{},
+ &discoverTask{want: 2},
},
},
// Peer 7 is connected, but there still aren't enough dynamic peers
@@ -218,7 +212,7 @@ func TestDialStateDynDial(t *testing.T) {
&discoverTask{},
},
new: []task{
- &discoverTask{},
+ &discoverTask{want: 2},
},
},
},
@@ -235,35 +229,37 @@ func TestDialStateDynDialBootnode(t *testing.T) {
},
Logger: testlog.Logger(t, log.LvlTrace),
}
- table := fakeTable{
- newNode(uintID(4), nil),
- newNode(uintID(5), nil),
- newNode(uintID(6), nil),
- newNode(uintID(7), nil),
- newNode(uintID(8), nil),
- }
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, table, 5, config),
+ init: newDialState(enode.ID{}, 5, config),
rounds: []round{
- // 2 dynamic dials attempted, bootnodes pending fallback interval
{
new: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
- &discoverTask{},
+ &discoverTask{want: 5},
},
},
- // No dials succeed, bootnodes still pending fallback interval
{
done: []task{
+ &discoverTask{
+ results: []*enode.Node{
+ newNode(uintID(4), nil),
+ newNode(uintID(5), nil),
+ },
+ },
+ },
+ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
+ &discoverTask{want: 3},
},
},
// No dials succeed, bootnodes still pending fallback interval
{},
- // No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached
+ // 1 bootnode attempted as fallback interval was reached
{
+ done: []task{
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
+ },
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
},
@@ -275,15 +271,12 @@ func TestDialStateDynDialBootnode(t *testing.T) {
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, 3rd bootnode is attempted
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
@@ -293,115 +286,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
- },
- new: []task{},
- },
- // Random dial succeeds, no more bootnodes are attempted
- {
- new: []task{
- &waitExpireTask{3 * time.Second},
- },
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
- },
- done: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- },
- },
- },
- })
-}
-
-func TestDialStateDynDialFromTable(t *testing.T) {
- // This table always returns the same random nodes
- // in the order given below.
- table := fakeTable{
- newNode(uintID(1), nil),
- newNode(uintID(2), nil),
- newNode(uintID(3), nil),
- newNode(uintID(4), nil),
- newNode(uintID(5), nil),
- newNode(uintID(6), nil),
- newNode(uintID(7), nil),
- newNode(uintID(8), nil),
- }
-
- runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
- rounds: []round{
- // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
- {
- new: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
- &discoverTask{},
- },
- },
- // Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
- {
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- },
- done: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&discoverTask{results: []*enode.Node{
- newNode(uintID(10), nil),
- newNode(uintID(11), nil),
- newNode(uintID(12), nil),
+ newNode(uintID(6), nil),
}},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
- &discoverTask{},
- },
- },
- // Dialing nodes 3,4,5 fails. The dials from the lookup succeed.
- {
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
- },
- done: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
+ &discoverTask{want: 4},
},
},
- // Waiting for expiry. No waitExpireTask is launched because the
- // discovery query is still running.
+ // Random dial succeeds, no more bootnodes are attempted
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
- },
- },
- // Nodes 3,4 are not tried again because only the first two
- // returned random nodes (nodes 1,2) are tried and they're
- // already connected.
- {
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}},
},
},
},
@@ -416,11 +313,11 @@ func newNode(id enode.ID, ip net.IP) *enode.Node {
return enode.SignNull(&r, id)
}
-// This test checks that candidates that do not match the netrestrict list are not dialed.
+// This test checks that candidates that do not match the netrestrict list are not dialed.
func TestDialStateNetRestrict(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
- table := fakeTable{
+ nodes := []*enode.Node{
newNode(uintID(1), net.ParseIP("127.0.0.1")),
newNode(uintID(2), net.ParseIP("127.0.0.2")),
newNode(uintID(3), net.ParseIP("127.0.0.3")),
@@ -434,12 +331,23 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
+ init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}),
rounds: []round{
{
new: []task{
- &dialTask{flags: dynDialedConn, dest: table[4]},
- &discoverTask{},
+ &discoverTask{want: 10},
+ },
+ },
+ {
+ done: []task{
+ &discoverTask{results: nodes},
+ },
+ new: []task{
+ &dialTask{flags: dynDialedConn, dest: nodes[4]},
+ &dialTask{flags: dynDialedConn, dest: nodes[5]},
+ &dialTask{flags: dynDialedConn, dest: nodes[6]},
+ &dialTask{flags: dynDialedConn, dest: nodes[7]},
+ &discoverTask{want: 6},
},
},
},
@@ -459,7 +367,7 @@ func TestDialStateStaticDial(t *testing.T) {
Logger: testlog.Logger(t, log.LvlTrace),
}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, fakeTable{}, 0, config),
+ init: newDialState(enode.ID{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -544,7 +452,7 @@ func TestDialStateCache(t *testing.T) {
Logger: testlog.Logger(t, log.LvlTrace),
}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, fakeTable{}, 0, config),
+ init: newDialState(enode.ID{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -618,8 +526,8 @@ func TestDialResolve(t *testing.T) {
Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
}
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
- table := &resolveMock{answer: resolved}
- state := newDialState(enode.ID{}, table, 0, config)
+ resolver := &resolveMock{answer: resolved}
+ state := newDialState(enode.ID{}, 0, config)
// Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil)
@@ -630,10 +538,14 @@ func TestDialResolve(t *testing.T) {
}
// Now run the task, it should resolve the ID once.
- srv := &Server{ntab: table, log: config.Logger, Config: *config}
+ srv := &Server{
+ Config: *config,
+ log: config.Logger,
+ staticNodeResolver: resolver,
+ }
tasks[0].Do(srv)
- if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
- t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)
+ if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) {
+ t.Fatalf("wrong resolve calls, got %v", resolver.calls)
}
// Report it as done to the dialer, which should update the static node record.
@@ -666,18 +578,13 @@ func uintID(i uint32) enode.ID {
return id
}
-// implements discoverTable for TestDialResolve
+// resolveMock implements nodeResolver for TestDialResolve.
type resolveMock struct {
- resolveCalls []*enode.Node
- answer *enode.Node
+ calls []*enode.Node
+ answer *enode.Node
}
func (t *resolveMock) Resolve(n *enode.Node) *enode.Node {
- t.resolveCalls = append(t.resolveCalls, n)
+ t.calls = append(t.calls, n)
return t.answer
}
-
-func (t *resolveMock) Self() *enode.Node { return new(enode.Node) }
-func (t *resolveMock) Close() {}
-func (t *resolveMock) LookupRandom() []*enode.Node { return nil }
-func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 }
diff --git a/p2p/discover/common.go b/p2p/discover/common.go
index 3c080359fdf8..5ec323809ba1 100644
--- a/p2p/discover/common.go
+++ b/p2p/discover/common.go
@@ -18,13 +18,15 @@ package discover
import (
"crypto/ecdsa"
"net"
+ "sync"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/netutil"
)
+// UDPConn is a network connection on which discovery can operate.
type UDPConn interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
@@ -32,7 +34,7 @@
LocalAddr() net.Addr
}
-// Config holds Table-related settings.
+// Config holds settings for the discovery listener.
type Config struct {
// These settings are required and configure the UDP listener:
PrivateKey *ecdsa.PrivateKey
@@ -50,8 +52,222 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
}
// ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled
-// channel if configured.
+// channel if configured. This is exported for internal use, do not use this type.
type ReadPacket struct {
Data []byte
Addr *net.UDPAddr
}
+
+type lookupFunc func(cancel <-chan struct{}, seenNode func(*enode.Node))
+
+// lookupWalker performs recursive lookups, walking the DHT.
+// It manages a set of iterators which receive lookup results as they are found.
+type lookupWalker struct {
+ lookup lookupFunc
+ closeCh chan struct{}
+
+ mu sync.Mutex
+ cond *sync.Cond
+ wg sync.WaitGroup
+ iters map[*lookupIterator]struct{}
+}
+
+func newLookupWalker(fn lookupFunc) *lookupWalker {
+ w := &lookupWalker{
+ lookup: fn,
+ closeCh: make(chan struct{}),
+ iters: make(map[*lookupIterator]struct{}),
+ }
+ w.cond = sync.NewCond(&w.mu)
+ w.wg.Add(1)
+ go w.loop()
+ return w
+}
+
+func (w *lookupWalker) close() {
+ close(w.closeCh)
+ // Wake a lookup blocked in foundNode so it can observe closeCh and exit.
+ w.mu.Lock()
+ w.cond.Broadcast()
+ w.mu.Unlock()
+ w.wg.Wait()
+}
+
+// loop schedules lookups. It ensures a lookup is running while
+// any live iterator needs more nodes.
+func (w *lookupWalker) loop() {
+ var (
+ done = make(chan struct{})
+ cancel = make(chan struct{})
+ running bool
+ )
+ for {
+ if !running {
+ running = true
+ go w.runLookup(cancel, done)
+ }
+ select {
+ case <-done:
+ running = false
+ case <-w.closeCh:
+ if running {
+ close(cancel)
+ <-done
+ }
+ goto shutdown
+ }
+ }
+
+shutdown:
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ for it := range w.iters {
+ it.close()
+ }
+ w.wg.Done()
+}
+
+func (w *lookupWalker) runLookup(cancel, done chan struct{}) {
+ w.lookup(cancel, w.foundNode)
+ done <- struct{}{}
+}
+
+func (w *lookupWalker) foundNode(n *enode.Node) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ for it := range w.iters {
+ it.deliver(n)
+ }
+ // Throttle the lookup until some iterator needs more nodes,
+ // but never block once the walker has been closed.
+ for !anyIterNeedsNodes(w.iters) {
+ select {
+ case <-w.closeCh:
+ return
+ default:
+ }
+ w.cond.Wait()
+ }
+}
+
+func (w *lookupWalker) newIterator(filter filterFunc) *lookupIterator {
+ it := newLookupIterator(w, filter)
+ it.walker.add(it)
+ return it
+}
+
+func (w *lookupWalker) add(it *lookupIterator) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ w.iters[it] = struct{}{}
+ w.unblockLookup()
+}
+
+func (w *lookupWalker) remove(it *lookupIterator) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ delete(w.iters, it)
+ w.unblockLookup()
+}
+
+func (w *lookupWalker) unblockLookup() {
+ w.cond.Signal()
+}
+
+func anyIterNeedsNodes(iters map[*lookupIterator]struct{}) bool {
+ for it := range iters {
+ if it.needsNodes() {
+ return true
+ }
+ }
+ return false
+}
+
+// lookupIterator is a sequence of discovered nodes.
+type lookupIterator struct {
+ cur *enode.Node
+ walker *lookupWalker
+ filter filterFunc
+ mu sync.Mutex
+ cond *sync.Cond
+ buf []*enode.Node
+}
+
+const lookupIteratorBuffer = 100
+
+type filterFunc func(*enode.Node) bool
+
+func newLookupIterator(w *lookupWalker, filter filterFunc) *lookupIterator {
+ if filter == nil {
+ filter = func(*enode.Node) bool { return true }
+ }
+ it := &lookupIterator{
+ walker: w,
+ filter: filter,
+ buf: make([]*enode.Node, 0, lookupIteratorBuffer),
+ }
+ it.cond = sync.NewCond(&it.mu)
+ return it
+}
+
+func (it *lookupIterator) Next() bool {
+ it.cur = nil
+
+ // Wait for the buffer to be filled.
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ for it.buf != nil && len(it.buf) == 0 {
+ it.cond.Wait()
+ }
+ if it.buf == nil {
+ return false // closed
+ }
+ it.cur = it.buf[0]
+ copy(it.buf, it.buf[1:])
+ it.buf = it.buf[:len(it.buf)-1]
+ it.walker.unblockLookup()
+ return true
+}
+
+func (it *lookupIterator) Node() *enode.Node {
+ return it.cur
+}
+
+func (it *lookupIterator) Close() {
+ it.walker.remove(it)
+ it.close()
+}
+
+func (it *lookupIterator) close() {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ if it.buf != nil {
+ it.buf = nil
+ it.cond.Signal()
+ }
+}
+
+// deliver places a node into the iterator buffer.
+func (it *lookupIterator) deliver(n *enode.Node) bool {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ if it.buf == nil || !it.filter(n) {
+ return true
+ }
+ if len(it.buf) == cap(it.buf) {
+ return false
+ }
+ it.buf = append(it.buf, n)
+ it.cond.Signal()
+ return true
+}
+
+// needsNodes reports whether the iterator is low on nodes.
+func (it *lookupIterator) needsNodes() bool {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ return len(it.buf) < lookupIteratorBuffer/3
+}
diff --git a/p2p/discover/common_test.go b/p2p/discover/common_test.go
new file mode 100644
index 000000000000..e257d9938c18
--- /dev/null
+++ b/p2p/discover/common_test.go
@@ -0,0 +1,160 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package discover
+
+import (
+ "encoding/binary"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/p2p/discutil"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+)
+
+// This test checks basic operation of the lookup iterator.
+func TestLookupIterator(t *testing.T) {
+ var (
+ test = newLookupWalkerTest()
+ testNodes = makeTestNodes(lookupIteratorBuffer)
+ wg sync.WaitGroup
+ )
+
+ testIterator := func(it discutil.Iterator) {
+ defer wg.Done()
+
+ // Check reading nodes:
+ nodes := discutil.ReadNodes(it, 20)
+ sortByID(nodes)
+ if err := checkNodesEqual(nodes, testNodes[:20]); err != nil {
+ t.Error(err)
+ }
+ nodes = discutil.ReadNodes(it, 20)
+ sortByID(nodes)
+ if err := checkNodesEqual(nodes, testNodes[20:40]); err != nil {
+ t.Error(err)
+ }
+
+ // Check close:
+ it.Close()
+ if it.Next() {
+ t.Error("Next returned true after close")
+ }
+ if it.Node() != nil {
+ t.Error("iterator has non-nil node after close")
+ }
+ it.Close() // shouldn't crash
+ }
+
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go testIterator(test.newIterator(nil))
+ }
+
+ test.serveOneLookup(testNodes[:10])
+ test.serveOneLookup(testNodes[10:20])
+ test.serveOneLookup(testNodes[20:40])
+ wg.Wait()
+
+ test.close()
+}
+
+func TestLookupIteratorClose(t *testing.T) {
+ test := newLookupWalkerTest()
+ defer test.close()
+ it := test.newIterator(nil)
+
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ it.Close()
+ }()
+ it.Next()
+}
+
+// This test checks that the iterator kicks off a lookup when Next is called.
+func TestLookupIteratorDrained(t *testing.T) {
+ var (
+ test = newLookupWalkerTest()
+ it = test.newIterator(nil)
+ testNodes = makeTestNodes(2 * lookupIteratorBuffer)
+ )
+
+ test.serveOneLookup(testNodes[:lookupIteratorBuffer])
+ nodes := discutil.ReadNodes(it, lookupIteratorBuffer)
+ sortByID(nodes)
+ if err := checkNodesEqual(nodes, testNodes[:lookupIteratorBuffer]); err != nil {
+ t.Fatal(err)
+ }
+
+ // Here the iterator buffer is drained and no lookup is running.
+
+ // Request more nodes. This needs to start another lookup.
+ go test.serveOneLookup(testNodes[lookupIteratorBuffer:])
+ nodes = discutil.ReadNodes(it, 10)
+ sortByID(nodes)
+ if err := checkNodesEqual(nodes, testNodes[lookupIteratorBuffer:lookupIteratorBuffer+10]); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func makeTestNodes(n int) []*enode.Node {
+ nodes := make([]*enode.Node, n)
+ for i := range nodes {
+ var nodeID enode.ID
+ binary.BigEndian.PutUint64(nodeID[:], uint64(i))
+ nodes[i] = enode.SignNull(new(enr.Record), nodeID)
+ }
+ return nodes
+}
+
+type lookupWalkerTest struct {
+ *lookupWalker
+ running int32
+ nodes chan []*enode.Node
+}
+
+func newLookupWalkerTest() *lookupWalkerTest {
+ wt := &lookupWalkerTest{nodes: make(chan []*enode.Node)}
+ wt.lookupWalker = newLookupWalker(wt.lookupFunc)
+ return wt
+}
+
+// serveOneLookup allows one lookupFunc call to happen and makes it find
+// the given nodes.
+func (t *lookupWalkerTest) serveOneLookup(nodes []*enode.Node) {
+ t.nodes <- nodes
+ <-t.nodes
+}
+
+func (t *lookupWalkerTest) lookupFunc(cancel <-chan struct{}, callback func(*enode.Node)) {
+ if atomic.AddInt32(&t.running, 1) != 1 {
+ panic("spawned more than one instance of lookupFunc")
+ }
+ defer atomic.AddInt32(&t.running, -1)
+
+ select {
+ case nodes := <-t.nodes:
+ for _, n := range nodes {
+ callback(n)
+ }
+ t.nodes <- nil
+ case <-cancel:
+ return
+ }
+}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index e0a46792b44b..e5a5793e358f 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -147,35 +147,18 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
- // Find all non-empty buckets and get a fresh slice of their entries.
- var buckets [][]*node
+ var nodes []*enode.Node
for _, b := range &tab.buckets {
- if len(b.entries) > 0 {
- buckets = append(buckets, b.entries)
+ for _, n := range b.entries {
+ nodes = append(nodes, unwrapNode(n))
}
}
- if len(buckets) == 0 {
- return 0
- }
- // Shuffle the buckets.
- for i := len(buckets) - 1; i > 0; i-- {
- j := tab.rand.Intn(len(buckets))
- buckets[i], buckets[j] = buckets[j], buckets[i]
- }
- // Move head of each bucket into buf, removing buckets that become empty.
- var i, j int
- for ; i < len(buf); i, j = i+1, (j+1)%len(buckets) {
- b := buckets[j]
- buf[i] = unwrapNode(b[0])
- buckets[j] = b[1:]
- if len(b) == 1 {
- buckets = append(buckets[:j], buckets[j+1:]...)
- }
- if len(buckets) == 0 {
- break
- }
+ // Shuffle.
+ for i := 0; i < len(nodes); i++ {
+ j := tab.rand.Intn(len(nodes))
+ nodes[i], nodes[j] = nodes[j], nodes[i]
}
- return i + 1
+ return copy(buf, nodes)
}
// getNode returns the node with the given ID or nil if it isn't in the table.
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
index 8e5fc7374b47..2adfd0dc0ec1 100644
--- a/p2p/discover/table_util_test.go
+++ b/p2p/discover/table_util_test.go
@@ -17,11 +17,14 @@
package discover
import (
+ "bytes"
"crypto/ecdsa"
"encoding/hex"
+ "errors"
"fmt"
"math/rand"
"net"
+ "reflect"
"sort"
"sync"
@@ -169,6 +172,28 @@ func hasDuplicates(slice []*node) bool {
return false
}
+func checkNodesEqual(got, want []*enode.Node) error {
+ if reflect.DeepEqual(got, want) {
+ return nil
+ }
+ output := new(bytes.Buffer)
+ fmt.Fprintf(output, "got %d nodes:\n", len(got))
+ for _, n := range got {
+ fmt.Fprintf(output, " %v %v\n", n.ID(), n)
+ }
+ fmt.Fprintf(output, "want %d:\n", len(want))
+ for _, n := range want {
+ fmt.Fprintf(output, " %v %v\n", n.ID(), n)
+ }
+ return errors.New(output.String())
+}
+
+func sortByID(nodes []*enode.Node) {
+ sort.Slice(nodes, func(i, j int) bool {
+ return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes())
+ })
+}
+
func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
return sort.SliceIsSorted(slice, func(i, j int) bool {
return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0
diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go
index b2a5d85cf421..ecfc98388117 100644
--- a/p2p/discover/v4_udp.go
+++ b/p2p/discover/v4_udp.go
@@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/discutil"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
@@ -202,6 +203,7 @@ type UDPv4 struct {
localNode *enode.LocalNode
db *enode.DB
tab *Table
+ randomWalk *lookupWalker
closeOnce sync.Once
wg sync.WaitGroup
@@ -270,6 +272,7 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
if t.log == nil {
t.log = log.Root()
}
+
tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log)
if err != nil {
return nil, err
@@ -277,6 +280,8 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
t.tab = tab
go tab.loop()
+ t.randomWalk = newLookupWalker(t.randomLookupWithCallback)
+
t.wg.Add(2)
go t.loop()
go t.readLoop(cfg.Unhandled)
@@ -295,47 +300,57 @@ func (t *UDPv4) Close() {
t.conn.Close()
t.wg.Wait()
t.tab.close()
+ t.randomWalk.close()
})
}
-// ReadRandomNodes reads random nodes from the local table.
-func (t *UDPv4) ReadRandomNodes(buf []*enode.Node) int {
- return t.tab.ReadRandomNodes(buf)
+// RandomNodes is an iterator yielding nodes from a random walk of the DHT.
+//
+// All iterators share the same random walk to minimize network traffic. Discovered nodes
+// are checked against the filter function and returned by the iterator only when the
+// filter returns true.
+func (t *UDPv4) RandomNodes(filter func(*enode.Node) bool) discutil.Iterator {
+ return t.randomWalk.newIterator(filter)
}
// LookupRandom finds random nodes in the network.
-func (t *UDPv4) LookupRandom() []*enode.Node {
+func (t *UDPv4) randomLookupWithCallback(cancel <-chan struct{}, callback func(*enode.Node)) {
if t.tab.len() == 0 {
// All nodes were dropped, refresh. The very first query will hit this
// case and run the bootstrapping logic.
<-t.tab.refresh()
}
- return t.lookupRandom()
+ var target encPubkey
+ crand.Read(target[:])
+ t.lookup(target, cancel, callback)
}
+// LookupPubkey finds the closest nodes to the given public key.
func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node {
if t.tab.len() == 0 {
// All nodes were dropped, refresh. The very first query will hit this
// case and run the bootstrapping logic.
<-t.tab.refresh()
}
- return unwrapNodes(t.lookup(encodePubkey(key)))
+ return unwrapNodes(t.lookup(encodePubkey(key), t.tab.closeReq, nil))
}
+// for Table
func (t *UDPv4) lookupRandom() []*enode.Node {
var target encPubkey
crand.Read(target[:])
- return unwrapNodes(t.lookup(target))
+ return unwrapNodes(t.lookup(target, t.tab.closeReq, nil))
}
+// for Table
func (t *UDPv4) lookupSelf() []*enode.Node {
- return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey)))
+ return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey), t.tab.closeReq, nil))
}
// lookup performs a network search for nodes close to the given target. It approaches the
// target by querying nodes that are closer to it on each iteration. The given target does
// not need to be an actual node identifier.
-func (t *UDPv4) lookup(targetKey encPubkey) []*node {
+func (t *UDPv4) lookup(targetKey encPubkey, cancel <-chan struct{}, nodeCallback func(*enode.Node)) []*node {
var (
target = enode.ID(crypto.Keccak256Hash(targetKey[:]))
asked = make(map[enode.ID]bool)
@@ -360,7 +375,7 @@ func (t *UDPv4) lookup(targetKey encPubkey) []*node {
if !asked[n.ID()] {
asked[n.ID()] = true
pendingQueries++
- go t.lookupWorker(n, targetKey, reply)
+ go t.lookupWorker(n, targetKey, reply, nodeCallback)
}
}
if pendingQueries == 0 {
@@ -375,7 +390,7 @@ func (t *UDPv4) lookup(targetKey encPubkey) []*node {
result.push(n, bucketSize)
}
}
- case <-t.tab.closeReq:
+ case <-cancel:
return nil // shutdown, no need to continue.
}
pendingQueries--
@@ -383,7 +398,7 @@ func (t *UDPv4) lookup(targetKey encPubkey) []*node {
return result.entries
}
-func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node) {
+func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node, callback func(*enode.Node)) {
fails := t.db.FindFails(n.ID(), n.IP())
r, err := t.findnode(n.ID(), n.addr(), targetKey)
if err == errClosed {
@@ -407,7 +422,11 @@ func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node)
// just remove those again during revalidation.
for _, n := range r {
t.tab.addSeenNode(n)
+ if callback != nil {
+ callback(unwrapNode(n))
+ }
}
+
reply <- r
}
diff --git a/p2p/discutil/iter.go b/p2p/discutil/iter.go
new file mode 100644
index 000000000000..cebd2236fd13
--- /dev/null
+++ b/p2p/discutil/iter.go
@@ -0,0 +1,239 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package discutil provides node discovery utilities.
+package discutil
+
+import (
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
// Iterator represents a sequence of nodes.
//
// Next moves the iterator to the next node of the sequence and reports whether one was
// found. Once the sequence has ended or the iterator was closed, Next should keep
// returning false.
//
// Close may be called concurrently with Next and Node, and interrupts Next if it is blocked.
type Iterator interface {
	Next() bool // moves to next node
	Node() *enode.Node // returns current node
	Close() // ends the iterator
}
+
+// ReadNodes reads at most n nodes from the given iterator. The return value contains no
+// duplicates and no nil values. To prevent looping indefinitely for small repeating node
+// sequences, this function calls NextNode at most n times.
+func ReadNodes(it Iterator, n int) []*enode.Node {
+ seen := make(map[enode.ID]*enode.Node, n)
+ for i := 0; i < n && it.Next(); i++ {
+ // Remove duplicates, keeping the node with higher seq.
+ node := it.Node()
+ prevNode, ok := seen[node.ID()]
+ if ok && prevNode.Seq() > node.Seq() {
+ continue
+ }
+ seen[node.ID()] = node
+ }
+ result := make([]*enode.Node, 0, len(seen))
+ for _, node := range seen {
+ result = append(result, node)
+ }
+ return result
+}
+
+// Filter wraps an iterator such that NextNode only returns nodes for which
+// the 'check' function returns true.
+func Filter(it Iterator, check func(*enode.Node) bool) Iterator {
+ return &filterIter{it, check}
+}
+
+type filterIter struct {
+ Iterator
+ check func(*enode.Node) bool
+}
+
+func (f *filterIter) Next() bool {
+ for f.Iterator.Next() {
+ if f.check(f.Node()) {
+ return true
+ }
+ }
+ return false
+}
+
// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends
// only when Close is called. Source iterators added via AddSource are removed from the mix
// when they end.
//
// The distribution of nodes returned by Next is approximately fair, i.e. FairMix
// attempts to draw from all sources equally often. However, if a certain source is slow
// and doesn't return a node within the configured timeout, a node from any other source
// will be returned.
//
// It's safe to call AddSource and Close concurrently with Next.
type FairMix struct {
	wg sync.WaitGroup // tracks runSource goroutines
	fromAny chan *enode.Node // receives nodes from any source
	timeout time.Duration // max wait for the fairly-chosen source
	cur *enode.Node // current node, set by Next

	mu sync.Mutex // protects the fields below
	closed chan struct{} // nil once the mixer is closed
	sources []*mixSource
	last int // index of the source last read from
}

// mixSource is one source iterator together with the channel its nodes
// are delivered on.
type mixSource struct {
	it Iterator
	next chan *enode.Node
}
+
// NewFairMix creates a mixer.
//
// The timeout specifies how long the mixer will wait for the next fairly-chosen source
// before giving up and taking a node from any other source. A good way to set the timeout
// is deciding how long you'd want to wait for a node on average. Passing a negative
// timeout makes the mixer completely fair: it always waits for the chosen source.
func NewFairMix(timeout time.Duration) *FairMix {
	m := &FairMix{
		fromAny: make(chan *enode.Node),
		closed: make(chan struct{}),
		timeout: timeout,
	}
	return m
}
+
+// AddSource adds a source of nodes.
+func (m *FairMix) AddSource(it Iterator) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.closed == nil {
+ return
+ }
+ m.wg.Add(1)
+ source := &mixSource{it, make(chan *enode.Node)}
+ m.sources = append(m.sources, source)
+ go m.runSource(m.closed, source)
+}
+
// Close shuts down the mixer and all current sources.
// Calling this is required to release resources associated with the mixer.
// Close is idempotent.
func (m *FairMix) Close() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.closed == nil {
		return // already closed
	}
	// Close the sources first so their runSource goroutines can exit.
	for _, s := range m.sources {
		s.it.Close()
	}
	close(m.closed)
	m.wg.Wait() // wait for all runSource goroutines to end
	close(m.fromAny)
	m.sources = nil
	m.closed = nil // marks the mixer closed for AddSource and Close
}
+
// Next reads a node from one of the sources. It prefers the fairly-chosen
// source, but takes a node from any source once the timeout elapses. It
// returns false when the mixer has been closed.
func (m *FairMix) Next() bool {
	m.cur = nil

	var timeout <-chan time.Time
	if m.timeout >= 0 {
		timer := time.NewTimer(m.timeout)
		timeout = timer.C
		defer timer.Stop()
	}
	for {
		source := m.pickSource()
		if source == nil {
			// No sources left; fall back to the shared channel.
			return m.nextFromAny()
		}
		select {
		case n, ok := <-source.next:
			if ok {
				m.cur = n
				return true
			}
			// This source has ended.
			m.deleteSource(source)
		case <-timeout:
			// The fair choice was too slow; take a node from any source.
			return m.nextFromAny()
		}
	}
}
+
// Node returns the node made current by the last successful call to Next.
// It is nil before the first Next call and after Next has returned false.
func (m *FairMix) Node() *enode.Node {
	return m.cur
}
+
+// nextFromAny is used when there are no sources or when the 'fair' choice
+// doesn't turn up a node quickly enough.
+func (m *FairMix) nextFromAny() bool {
+ n, ok := <-m.fromAny
+ if ok {
+ m.cur = n
+ }
+ return ok
+}
+
+// pickSource chooses the next source to read from, cycling through them in order.
+func (m *FairMix) pickSource() *mixSource {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if len(m.sources) == 0 {
+ return nil
+ }
+ m.last = (m.last + 1) % len(m.sources)
+ return m.sources[m.last]
+}
+
+// deleteSource deletes a source.
+func (m *FairMix) deleteSource(s *mixSource) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ for i := range m.sources {
+ if m.sources[i] == s {
+ copy(m.sources[i:], m.sources[i+1:])
+ m.sources[len(m.sources)-1] = nil
+ m.sources = m.sources[:len(m.sources)-1]
+ break
+ }
+ }
+}
+
// runSource reads a single source in a loop, delivering each node either on
// the source's own channel (the fair path) or on the shared fromAny channel.
// It exits when the source ends or the mixer is closed.
func (m *FairMix) runSource(closed chan struct{}, s *mixSource) {
	defer m.wg.Done()
	defer close(s.next) // signals Next that this source has ended
	for s.it.Next() {
		n := s.it.Node()
		select {
		case s.next <- n:
		case m.fromAny <- n:
		case <-closed:
			return
		}
	}
}
diff --git a/p2p/discutil/iter_test.go b/p2p/discutil/iter_test.go
new file mode 100644
index 000000000000..4d65d67a223f
--- /dev/null
+++ b/p2p/discutil/iter_test.go
@@ -0,0 +1,275 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package discutil
+
+import (
+ "encoding/binary"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+)
+
// TestReadNodes checks that ReadNodes returns the requested number of unique,
// non-nil nodes from a generating iterator.
func TestReadNodes(t *testing.T) {
	nodes := ReadNodes(new(genIter), 10)
	checkNodes(t, nodes, 10)
}
+
+// This test checks that ReadNodes terminates when reading N nodes from an iterator
+// which returns less than N nodes in an endless cycle.
+func TestReadNodesCycle(t *testing.T) {
+ iter := &callCountIter{
+ Iterator: cycleNodes(
+ testNode(0, 0),
+ testNode(1, 0),
+ testNode(2, 0),
+ ),
+ }
+ nodes := ReadNodes(iter, 10)
+ checkNodes(t, nodes, 3)
+ if iter.count != 10 {
+ t.Fatalf("%d calls to Next, want %d", iter.count, 100)
+ }
+}
+
+func checkNodes(t *testing.T, nodes []*enode.Node, wantLen int) {
+ if len(nodes) != wantLen {
+ t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen)
+ return
+ }
+ seen := make(map[enode.ID]bool)
+ for i, e := range nodes {
+ if e == nil {
+ t.Errorf("nil node at index %d", i)
+ return
+ }
+ if seen[e.ID()] {
+ t.Errorf("slice has duplicate node %v", e.ID())
+ return
+ }
+ seen[e.ID()] = true
+ }
+}
+
// This test checks fairness of FairMix in the happy case where all sources return nodes
// within the timeout. The check is repeated 500 times because an unfair draw
// depends on goroutine scheduling.
func TestFairMix(t *testing.T) {
	for i := 0; i < 500; i++ {
		testMixerFairness(t)
	}
}
+
// testMixerFairness mixes three generated sources and verifies that each
// contributes approximately a third of the returned nodes.
func testMixerFairness(t *testing.T) {
	mix := NewFairMix(1 * time.Second)
	mix.AddSource(&genIter{index: 1})
	mix.AddSource(&genIter{index: 2})
	mix.AddSource(&genIter{index: 3})
	defer mix.Close()

	nodes := ReadNodes(mix, 500)
	checkNodes(t, nodes, 500)

	// Verify that the nodes slice contains an approximately equal number of nodes
	// from each source.
	d := idPrefixDistribution(nodes)
	for _, count := range d {
		// NOTE: despite its name, approxEqual returns true when the values
		// differ by MORE than the tolerance, so this fails on unfair counts.
		if approxEqual(count, len(nodes)/3, 30) {
			t.Fatalf("ID distribution is unfair: %v", d)
		}
	}
}
+
// This test checks that FairMix falls back to an alternative source when
// the 'fair' choice doesn't return a node within the timeout. The second
// source is an empty cycle that never yields a node, so all results must
// come from the first source.
func TestFairMixNextFromAll(t *testing.T) {
	mix := NewFairMix(1 * time.Millisecond)
	mix.AddSource(&genIter{index: 1})
	mix.AddSource(cycleNodes())
	defer mix.Close()

	nodes := ReadNodes(mix, 500)
	checkNodes(t, nodes, 500)

	// All IDs must carry prefix 1, i.e. originate from the genIter source.
	d := idPrefixDistribution(nodes)
	if len(d) > 1 || d[1] != len(nodes) {
		t.Fatalf("wrong ID distribution: %v", d)
	}
}
+
// This test ensures FairMix works for Next with no sources: Next blocks
// until a source is added and then yields that source's node.
func TestFairMixEmpty(t *testing.T) {
	var (
		mix = NewFairMix(1 * time.Second)
		testN = testNode(1, 1)
		ch = make(chan *enode.Node)
	)
	defer mix.Close()

	go func() {
		mix.Next()
		ch <- mix.Node()
	}()

	mix.AddSource(cycleNodes(testN))
	if n := <-ch; n != testN {
		t.Errorf("got wrong node: %v", n)
	}
}
+
// This test checks that a source which was closed before being added is
// removed from the mix on the first read attempt.
//
// NOTE(review): once the last source is deleted, Next falls through to
// nextFromAny, which blocks until Close — verify this test cannot hang when
// the runSource goroutine wins the race and closes source.next first.
func TestFairMixRemoveSource(t *testing.T) {
	mix := NewFairMix(1 * time.Second)
	source := cycleNodes()
	source.Close()
	mix.AddSource(source)

	if mix.Next() {
		t.Fatal("Next should've returned false")
	}
	if len(mix.sources) != 0 {
		t.Fatalf("have %d sources, want zero", len(mix.sources))
	}
}
+
+func TestFairMixClose(t *testing.T) {
+ for i := 0; i < 20 && !t.Failed(); i++ {
+ testMixerClose(t)
+ }
+}
+
// testMixerClose checks that a blocked Next call is unblocked by Close and
// reports false.
func testMixerClose(t *testing.T) {
	mix := NewFairMix(-1)
	mix.AddSource(cycleNodes())
	mix.AddSource(cycleNodes())

	done := make(chan struct{})
	go func() {
		defer close(done)
		if mix.Next() {
			t.Error("Next returned true")
		}
	}()
	// This call is supposed to make it more likely that Next is
	// actually executing by the time we call Close.
	runtime.Gosched()

	mix.Close()
	select {
	case <-done:
	case <-time.After(3 * time.Second):
		t.Fatal("Next didn't unblock on Close")
	}

	mix.Close() // shouldn't crash
}
+
+func idPrefixDistribution(nodes []*enode.Node) map[uint32]int {
+ d := make(map[uint32]int)
+ for _, node := range nodes {
+ id := node.ID()
+ d[binary.BigEndian.Uint32(id[:4])]++
+ }
+ return d
+}
+
// approxEqual returns true when x and y differ by MORE than ε. Note that,
// despite the name, this is a deviation check rather than an equality check;
// testMixerFairness depends on this meaning to detect unfair distributions.
func approxEqual(x, y, ε int) bool {
	diff := x - y
	if diff < 0 {
		diff = -diff
	}
	return diff > ε
}
+
// genIter creates fake nodes with numbered IDs based on 'index' and 'gen'.
type genIter struct {
	node *enode.Node
	index, gen uint32 // index is accessed atomically; ^uint32(0) marks the iterator closed
}

// Next generates the next fake node. It returns false once Close has marked
// the iterator as ended.
func (s *genIter) Next() bool {
	index := atomic.LoadUint32(&s.index)
	if index == ^uint32(0) {
		s.node = nil
		return false
	}
	// ID layout: 'index' in the upper 32 bits, generation counter in the
	// lower 32 bits, so distinct sources produce distinct ID prefixes.
	s.node = testNode(uint64(index)<<32|uint64(s.gen), 0)
	s.gen++
	return true
}

// Node returns the node created by the last call to Next.
func (s *genIter) Node() *enode.Node {
	return s.node
}
+
+func (s *genIter) Close() {
+ s.index = ^uint32(0)
+}
+
+func testNode(id, seq uint64) *enode.Node {
+ var nodeID enode.ID
+ binary.BigEndian.PutUint64(nodeID[:], id)
+ r := new(enr.Record)
+ r.SetSeq(seq)
+ return enode.SignNull(r, nodeID)
+}
+
// cycleNodes is an iterator that cycles through the given slice.
func cycleNodes(nodes ...*enode.Node) Iterator {
	return &cycleIter{nodes: nodes}
}

// cycleIter cycles through a fixed node set until Close drops the set.
type cycleIter struct {
	cur *enode.Node // node selected by the last Next call
	mu sync.Mutex // protects index and nodes
	index int
	nodes []*enode.Node
}

// Next selects the next node in the cycle, remembering it in s.cur. It
// returns false once Close has emptied the node set.
func (s *cycleIter) Next() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if len(s.nodes) == 0 {
		return false
	}
	s.cur = s.nodes[s.index]
	// Advance so the following Next call yields the next element.
	s.index = (s.index + 1) % len(s.nodes)
	return true
}
+
+func (s *cycleIter) Node() *enode.Node {
+ return s.nodes[s.index]
+}
+
+func (s *cycleIter) Close() {
+ s.mu.Lock()
+ s.nodes = nil
+ s.mu.Unlock()
+}
+
// callCountIter counts calls to Next.
type callCountIter struct {
	Iterator
	count int // number of Next calls made so far
}

// Next forwards to the wrapped iterator, counting the call.
func (it *callCountIter) Next() bool {
	it.count++
	return it.Iterator.Next()
}
diff --git a/p2p/protocol.go b/p2p/protocol.go
index 9438ab8e47dd..30496709311c 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -19,6 +19,7 @@ package p2p
import (
"fmt"
+ "github.com/ethereum/go-ethereum/p2p/discutil"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
)
@@ -54,6 +55,11 @@ type Protocol struct {
// but returns nil, it is assumed that the protocol handshake is still running.
PeerInfo func(id enode.ID) interface{}
+ // DialCandidates, if non-nil, is a way to tell Server about protocol-specific nodes
+ // that should be dialed. The server continuously reads nodes from the iterator and
+ // attempts to create connections to them.
+ DialCandidates discutil.Iterator
+
// Attributes contains protocol specific information for the node record.
Attributes []enr.Entry
}
diff --git a/p2p/server.go b/p2p/server.go
index b7340a5ea718..798d3162874a 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -35,6 +35,7 @@ import (
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/p2p/discutil"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
@@ -167,16 +168,20 @@ type Server struct {
lock sync.Mutex // protects running
running bool
- nodedb *enode.DB
- localnode *enode.LocalNode
- ntab discoverTable
listener net.Listener
ourHandshake *protoHandshake
- DiscV5 *discv5.Network
loopWG sync.WaitGroup // loop, listenLoop
peerFeed event.Feed
log log.Logger
+ nodedb *enode.DB
+ localnode *enode.LocalNode
+ ntab *discover.UDPv4
+ DiscV5 *discv5.Network
+ discmix *discutil.FairMix
+
+ staticNodeResolver nodeResolver
+
// Channels into the run loop.
quit chan struct{}
addstatic chan *enode.Node
@@ -465,7 +470,7 @@ func (srv *Server) Start() (err error) {
}
dynPeers := srv.maxDialedConns()
- dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config)
+ dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config)
srv.loopWG.Add(1)
go srv.run(dialer)
return nil
@@ -517,6 +522,8 @@ func (srv *Server) setupLocalNode() error {
}
func (srv *Server) setupDiscovery() error {
+ srv.discmix = discutil.NewFairMix(fallbackInterval)
+
if srv.NoDiscovery && !srv.DiscoveryV5 {
return nil
}
@@ -558,7 +565,10 @@ func (srv *Server) setupDiscovery() error {
return err
}
srv.ntab = ntab
+ srv.discmix.AddSource(ntab.RandomNodes(nil))
+ srv.staticNodeResolver = ntab
}
+
// Discovery V5
if srv.DiscoveryV5 {
var ntab *discv5.Network
diff --git a/p2p/server_test.go b/p2p/server_test.go
index e8bc627e1d30..9f01ffc16bf6 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -234,7 +234,6 @@ func TestServerTaskScheduling(t *testing.T) {
localnode: enode.NewLocalNode(db, newkey()),
nodedb: db,
quit: make(chan struct{}),
- ntab: fakeTable{},
running: true,
log: log.New(),
}
@@ -282,7 +281,6 @@ func TestServerManyTasks(t *testing.T) {
quit: make(chan struct{}),
localnode: enode.NewLocalNode(db, newkey()),
nodedb: db,
- ntab: fakeTable{},
running: true,
log: log.New(),
}