diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 1e56687a6c78..28b0a05a2183 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -38,6 +38,7 @@ var ( discv4PingCommand, discv4RequestRecordCommand, discv4ResolveCommand, + discv4RandomWalkCommand, }, } discv4PingCommand = cli.Command{ @@ -56,6 +57,12 @@ var ( Action: discv4Resolve, Flags: []cli.Flag{bootnodesFlag}, } + discv4RandomWalkCommand = cli.Command{ + Name: "randomwalk", + Usage: "Prints random nodes found in the DHT", + Action: discv4RandomNodes, + Flags: []cli.Flag{bootnodesFlag}, + } ) var bootnodesFlag = cli.StringFlag{ @@ -104,6 +111,24 @@ func discv4Resolve(ctx *cli.Context) error { return nil } +func discv4RandomNodes(ctx *cli.Context) error { + bootnodes, err := parseBootnodes(ctx) + if err != nil { + return err + } + disc, err := startV4(bootnodes) + if err != nil { + return err + } + defer disc.Close() + + it := disc.RandomNodes(nil) + for it.Next() { + fmt.Println(it.Node()) + } + return nil +} + func getNodeArgAndStartV4(ctx *cli.Context) (*enode.Node, *discover.UDPv4, error) { if ctx.NArg() != 1 { return nil, nil, fmt.Errorf("missing node as command-line argument") diff --git a/p2p/dial.go b/p2p/dial.go index 8dee5063f1d5..b2e20ba338f5 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -23,6 +23,7 @@ import ( "time" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discutil" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/netutil" ) @@ -33,12 +34,10 @@ const ( // private networks. dialHistoryExpiration = inboundThrottleTime + 5*time.Second - // Discovery lookups are throttled and can only run - // once every few seconds. - lookupInterval = 4 * time.Second + // Timeout for NextNode on the discovery iterator. + discoveryTimeout = 4 * time.Second - // If no peers are found for this amount of time, the initial bootnodes are - // attempted to be connected. 
+ // If no peers are found for this amount of time, the initial bootnodes are dialed. fallbackInterval = 20 * time.Second // Endpoint resolution is throttled with bounded backoff. @@ -52,6 +51,10 @@ type NodeDialer interface { Dial(*enode.Node) (net.Conn, error) } +type nodeResolver interface { + Resolve(*enode.Node) *enode.Node +} + // TCPDialer implements the NodeDialer interface by using a net.Dialer to // create TCP connections to nodes in the network type TCPDialer struct { @@ -69,7 +72,6 @@ func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) { // of the main loop in Server.run. type dialstate struct { maxDynDials int - ntab discoverTable netrestrict *netutil.Netlist self enode.ID bootnodes []*enode.Node // default dials when there are no peers @@ -79,55 +81,23 @@ type dialstate struct { lookupRunning bool dialing map[enode.ID]connFlag lookupBuf []*enode.Node // current discovery lookup results - randomNodes []*enode.Node // filled from Table static map[enode.ID]*dialTask hist expHeap } -type discoverTable interface { - Close() - Resolve(*enode.Node) *enode.Node - LookupRandom() []*enode.Node - ReadRandomNodes([]*enode.Node) int -} - type task interface { Do(*Server) } -// A dialTask is generated for each node that is dialed. Its -// fields cannot be accessed while the task is running. -type dialTask struct { - flags connFlag - dest *enode.Node - lastResolved time.Time - resolveDelay time.Duration -} - -// discoverTask runs discovery table operations. -// Only one discoverTask is active at any time. -// discoverTask.Do performs a random lookup. -type discoverTask struct { - results []*enode.Node -} - -// A waitExpireTask is generated if there are no other tasks -// to keep the loop in Server.run ticking. 
-type waitExpireTask struct { - time.Duration -} - -func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate { +func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate { s := &dialstate{ maxDynDials: maxdyn, - ntab: ntab, self: self, netrestrict: cfg.NetRestrict, log: cfg.Logger, static: make(map[enode.ID]*dialTask), dialing: make(map[enode.ID]connFlag), bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), - randomNodes: make([]*enode.Node, maxdyn/2), } copy(s.bootnodes, cfg.BootstrapNodes) if s.log == nil { @@ -151,10 +121,6 @@ func (s *dialstate) removeStatic(n *enode.Node) { } func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { - if s.start.IsZero() { - s.start = now - } - var newtasks []task addDial := func(flag connFlag, n *enode.Node) bool { if err := s.checkDial(n, peers); err != nil { @@ -166,20 +132,9 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti return true } - // Compute number of dynamic dials necessary at this point. - needDynDials := s.maxDynDials - for _, p := range peers { - if p.rw.is(dynDialedConn) { - needDynDials-- - } - } - for _, flag := range s.dialing { - if flag&dynDialedConn != 0 { - needDynDials-- - } + if s.start.IsZero() { + s.start = now } - - // Expire the dial history on every invocation. s.hist.expire(now) // Create dials for static nodes if they are not connected. @@ -194,6 +149,20 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti newtasks = append(newtasks, t) } } + + // Compute number of dynamic dials needed. + needDynDials := s.maxDynDials + for _, p := range peers { + if p.rw.is(dynDialedConn) { + needDynDials-- + } + } + for _, flag := range s.dialing { + if flag&dynDialedConn != 0 { + needDynDials-- + } + } + // If we don't have any peers whatsoever, try to dial a random bootnode. 
This // scenario is useful for the testnet (and private networks) where the discovery // table might be full of mostly bad peers, making it hard to find good ones. @@ -201,24 +170,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti bootnode := s.bootnodes[0] s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...) s.bootnodes = append(s.bootnodes, bootnode) - if addDial(dynDialedConn, bootnode) { needDynDials-- } } - // Use random nodes from the table for half of the necessary - // dynamic dials. - randomCandidates := needDynDials / 2 - if randomCandidates > 0 { - n := s.ntab.ReadRandomNodes(s.randomNodes) - for i := 0; i < randomCandidates && i < n; i++ { - if addDial(dynDialedConn, s.randomNodes[i]) { - needDynDials-- - } - } - } - // Create dynamic dials from random lookup results, removing tried - // items from the result buffer. + + // Create dynamic dials from discovery results. i := 0 for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { if addDial(dynDialedConn, s.lookupBuf[i]) { @@ -226,10 +183,11 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti } } s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] + // Launch a discovery lookup if more candidates are needed. if len(s.lookupBuf) < needDynDials && !s.lookupRunning { s.lookupRunning = true - newtasks = append(newtasks, &discoverTask{}) + newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)}) } // Launch a timer to wait for the next node to expire if all @@ -279,6 +237,15 @@ func (s *dialstate) taskDone(t task, now time.Time) { } } +// A dialTask is generated for each node that is dialed. Its +// fields cannot be accessed while the task is running. 
+type dialTask struct { + flags connFlag + dest *enode.Node + lastResolved time.Time + resolveDelay time.Duration +} + func (t *dialTask) Do(srv *Server) { if t.dest.Incomplete() { if !t.resolve(srv) { @@ -304,7 +271,7 @@ func (t *dialTask) Do(srv *Server) { // discovery network with useless queries for nodes that don't exist. // The backoff delay resets when the node is found. func (t *dialTask) resolve(srv *Server) bool { - if srv.ntab == nil { + if srv.staticNodeResolver == nil { srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") return false } @@ -314,7 +281,7 @@ func (t *dialTask) resolve(srv *Server) bool { if time.Since(t.lastResolved) < t.resolveDelay { return false } - resolved := srv.ntab.Resolve(t.dest) + resolved := srv.staticNodeResolver.Resolve(t.dest) t.lastResolved = time.Now() if resolved == nil { t.resolveDelay *= 2 @@ -350,26 +317,34 @@ func (t *dialTask) String() string { return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) } +// discoverTask runs discovery table operations. +// Only one discoverTask is active at any time. +// discoverTask.Do performs a random lookup. +type discoverTask struct { + want int + results []*enode.Node +} + func (t *discoverTask) Do(srv *Server) { - // newTasks generates a lookup task whenever dynamic dials are - // necessary. Lookups need to take some time, otherwise the - // event loop spins too fast. 
- next := srv.lastLookup.Add(lookupInterval) - if now := time.Now(); now.Before(next) { - time.Sleep(next.Sub(now)) - } - srv.lastLookup = time.Now() - t.results = srv.ntab.LookupRandom() + t.results = discutil.ReadNodes(srv.discmix, t.want) } func (t *discoverTask) String() string { - s := "discovery lookup" + s := "discovery query" if len(t.results) > 0 { s += fmt.Sprintf(" (%d results)", len(t.results)) + } else { + s += fmt.Sprintf(" (want %d)", t.want) } return s } +// A waitExpireTask is generated if there are no other tasks +// to keep the loop in Server.run ticking. +type waitExpireTask struct { + time.Duration +} + func (t waitExpireTask) Do(*Server) { time.Sleep(t.Duration) } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index de8fc4a6e3e6..6189ec4d0b85 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -73,7 +73,7 @@ func runDialTest(t *testing.T, test dialtest) { t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v", i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) } - t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new))) + t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new))) // Time advances by 16 seconds on every round. vtime = vtime.Add(16 * time.Second) @@ -81,19 +81,11 @@ func runDialTest(t *testing.T, test dialtest) { } } -type fakeTable []*enode.Node - -func (t fakeTable) Self() *enode.Node { return new(enode.Node) } -func (t fakeTable) Close() {} -func (t fakeTable) LookupRandom() []*enode.Node { return nil } -func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil } -func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) } - // This test checks that dynamic dials are launched from discovery results. 
func TestDialStateDynDial(t *testing.T) { config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 5, config), + init: newDialState(enode.ID{}, 5, config), rounds: []round{ // A discovery query is launched. { @@ -102,7 +94,9 @@ func TestDialStateDynDial(t *testing.T) { {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, - new: []task{&discoverTask{}}, + new: []task{ + &discoverTask{want: 3}, + }, }, // Dynamic dials are launched when it completes. { @@ -188,7 +182,7 @@ func TestDialStateDynDial(t *testing.T) { }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, - &discoverTask{}, + &discoverTask{want: 2}, }, }, // Peer 7 is connected, but there still aren't enough dynamic peers @@ -218,7 +212,7 @@ func TestDialStateDynDial(t *testing.T) { &discoverTask{}, }, new: []task{ - &discoverTask{}, + &discoverTask{want: 2}, }, }, }, @@ -235,35 +229,37 @@ func TestDialStateDynDialBootnode(t *testing.T) { }, Logger: testlog.Logger(t, log.LvlTrace), } - table := fakeTable{ - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), - newNode(uintID(7), nil), - newNode(uintID(8), nil), - } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 5, config), + init: newDialState(enode.ID{}, 5, config), rounds: []round{ - // 2 dynamic dials attempted, bootnodes pending fallback interval { new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{}, + &discoverTask{want: 5}, }, }, - // No dials succeed, bootnodes still pending fallback interval { done: []task{ + &discoverTask{ + results: []*enode.Node{ + newNode(uintID(4), nil), + newNode(uintID(5), nil), + }, + }, + }, + new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, &dialTask{flags: 
dynDialedConn, dest: newNode(uintID(5), nil)}, + &discoverTask{want: 3}, }, }, // No dials succeed, bootnodes still pending fallback interval {}, - // No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached + // 1 bootnode attempted as fallback interval was reached { + done: []task{ + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, + }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, }, @@ -275,15 +271,12 @@ func TestDialStateDynDialBootnode(t *testing.T) { }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 3rd bootnode is attempted { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, @@ -293,115 +286,19 @@ func TestDialStateDynDialBootnode(t *testing.T) { { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - }, - new: []task{}, - }, - // Random dial succeeds, no more bootnodes are attempted - { - new: []task{ - &waitExpireTask{3 * time.Second}, - }, - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - }, - }, - }, - }) -} - -func TestDialStateDynDialFromTable(t *testing.T) { - // This table always returns the same random nodes - // in the order given below. 
- table := fakeTable{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), - newNode(uintID(7), nil), - newNode(uintID(8), nil), - } - - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}), - rounds: []round{ - // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. - { - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{}, - }, - }, - // Dialing nodes 1,2 succeeds. Dials from the lookup are launched. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, &discoverTask{results: []*enode.Node{ - newNode(uintID(10), nil), - newNode(uintID(11), nil), - newNode(uintID(12), nil), + newNode(uintID(6), nil), }}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, - &discoverTask{}, - }, - }, - // Dialing nodes 3,4,5 fails. The dials from the lookup succeed. 
- { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, + &discoverTask{want: 4}, }, }, - // Waiting for expiry. No waitExpireTask is launched because the - // discovery query is still running. + // Random dial succeeds, no more bootnodes are attempted { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, - }, - }, - // Nodes 3,4 are not tried again because only the first two - // returned random nodes (nodes 1,2) are tried and they're - // already connected. 
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
+					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}},
 				},
 			},
 		},
 	})
 }
@@ -416,11 +313,11 @@ func newNode(id enode.ID, ip net.IP) *enode.Node {
 	return enode.SignNull(&r, id)
 }
 
-// This test checks that candidates that do not match the netrestrict list are not dialed.
+// This test checks that candidates that do not match the netrestrict list are not dialed.
 func TestDialStateNetRestrict(t *testing.T) {
 	// This table always returns the same random nodes
 	// in the order given below.
-	table := fakeTable{
+	nodes := []*enode.Node{
 		newNode(uintID(1), net.ParseIP("127.0.0.1")),
 		newNode(uintID(2), net.ParseIP("127.0.0.2")),
 		newNode(uintID(3), net.ParseIP("127.0.0.3")),
@@ -434,12 +331,23 @@ func TestDialStateNetRestrict(t *testing.T) {
 	restrict.Add("127.0.2.0/24")
 
 	runDialTest(t, dialtest{
-		init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
+		init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}),
 		rounds: []round{
 			{
 				new: []task{
-					&dialTask{flags: dynDialedConn, dest: table[4]},
-					&discoverTask{},
+					&discoverTask{want: 10},
+				},
+			},
+			{
+				done: []task{
+					&discoverTask{results: nodes},
+				},
+				new: []task{
+					&dialTask{flags: dynDialedConn, dest: nodes[4]},
+					&dialTask{flags: dynDialedConn, dest: nodes[5]},
+					&dialTask{flags: dynDialedConn, dest: nodes[6]},
+					&dialTask{flags: dynDialedConn, dest: nodes[7]},
+					&discoverTask{want: 6},
 				},
 			},
 		},
@@ -459,7 +367,7 @@ func TestDialStateStaticDial(t *testing.T) {
 		Logger: testlog.Logger(t, log.LvlTrace),
 	}
 	runDialTest(t, dialtest{
-		init: newDialState(enode.ID{}, fakeTable{}, 0, config),
+		init: newDialState(enode.ID{}, 0,
config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -544,7 +452,7 @@ func TestDialStateCache(t *testing.T) { Logger: testlog.Logger(t, log.LvlTrace), } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 0, config), + init: newDialState(enode.ID{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -618,8 +526,8 @@ func TestDialResolve(t *testing.T) { Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, } resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) - table := &resolveMock{answer: resolved} - state := newDialState(enode.ID{}, table, 0, config) + resolver := &resolveMock{answer: resolved} + state := newDialState(enode.ID{}, 0, config) // Check that the task is generated with an incomplete ID. dest := newNode(uintID(1), nil) @@ -630,10 +538,14 @@ func TestDialResolve(t *testing.T) { } // Now run the task, it should resolve the ID once. - srv := &Server{ntab: table, log: config.Logger, Config: *config} + srv := &Server{ + Config: *config, + log: config.Logger, + staticNodeResolver: resolver, + } tasks[0].Do(srv) - if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { - t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) + if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) { + t.Fatalf("wrong resolve calls, got %v", resolver.calls) } // Report it as done to the dialer, which should update the static node record. 
@@ -666,18 +578,13 @@ func uintID(i uint32) enode.ID {
 	return id
 }
 
-// implements discoverTable for TestDialResolve
+// for TestDialResolve
 type resolveMock struct {
-	resolveCalls []*enode.Node
-	answer       *enode.Node
+	calls  []*enode.Node
+	answer *enode.Node
 }
 
 func (t *resolveMock) Resolve(n *enode.Node) *enode.Node {
-	t.resolveCalls = append(t.resolveCalls, n)
+	t.calls = append(t.calls, n)
 	return t.answer
 }
-
-func (t *resolveMock) Self() *enode.Node                     { return new(enode.Node) }
-func (t *resolveMock) Close()                                {}
-func (t *resolveMock) LookupRandom() []*enode.Node           { return nil }
-func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 }
diff --git a/p2p/discover/common.go b/p2p/discover/common.go
index 3c080359fdf8..5ec323809ba1 100644
--- a/p2p/discover/common.go
+++ b/p2p/discover/common.go
@@ -18,13 +18,15 @@ package discover
 
 import (
 	"crypto/ecdsa"
 	"net"
+	"sync"
 
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/netutil"
 )
 
+// UDPConn is a network connection on which discovery can operate.
 type UDPConn interface {
 	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
 	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
 	LocalAddr() net.Addr
 }
 
-// Config holds Table-related settings.
+// Config holds settings for the discovery listener.
 type Config struct {
 	// These settings are required and configure the UDP listener:
 	PrivateKey *ecdsa.PrivateKey
@@ -50,8 +53,211 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
 }
 
 // ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled
-// channel if configured.
+// channel if configured. This is exported for internal use, do not use this type.
type ReadPacket struct { Data []byte Addr *net.UDPAddr } + +type lookupFunc func(cancel <-chan struct{}, seenNode func(*enode.Node)) + +// lookupWalker performs recursive lookups, walking the DHT. +// It manages a set iterators which receive lookup results as they are found. +type lookupWalker struct { + lookup lookupFunc + closeCh chan struct{} + + mu sync.Mutex + cond *sync.Cond + wg sync.WaitGroup + iters map[*lookupIterator]struct{} +} + +func newLookupWalker(fn lookupFunc) *lookupWalker { + w := &lookupWalker{ + lookup: fn, + closeCh: make(chan struct{}), + iters: make(map[*lookupIterator]struct{}), + } + w.cond = sync.NewCond(&w.mu) + w.wg.Add(1) + go w.loop() + return w +} + +func (w *lookupWalker) close() { + close(w.closeCh) + w.wg.Wait() +} + +// loop schedules lookups. It ensures a lookup is running while +// any live iterator needs more nodes. +func (w *lookupWalker) loop() { + var ( + done = make(chan struct{}) + cancel = make(chan struct{}) + running bool + ) + for { + if !running { + go w.runLookup(cancel, done) + } + select { + case <-done: + case <-w.closeCh: + if running { + close(cancel) + <-done + } + goto shutdown + } + } + +shutdown: + w.mu.Lock() + defer w.mu.Unlock() + for it := range w.iters { + it.close() + } + w.wg.Done() +} + +func (w *lookupWalker) runLookup(cancel, done chan struct{}) { + w.lookup(cancel, w.foundNode) + done <- struct{}{} +} + +func (w *lookupWalker) foundNode(n *enode.Node) { + w.mu.Lock() + defer w.mu.Unlock() + for it := range w.iters { + it.deliver(n) + fmt.Println("delivered", len(it.buf), it.needsNodes()) + } + for !anyIterNeedsNodes(w.iters) { + w.cond.Wait() + } +} + +func (w *lookupWalker) newIterator(filter filterFunc) *lookupIterator { + it := newLookupIterator(w, filter) + it.walker.add(it) + return it +} + +func (w *lookupWalker) add(it *lookupIterator) { + w.mu.Lock() + defer w.mu.Unlock() + w.iters[it] = struct{}{} + w.unblockLookup() +} + +func (w *lookupWalker) remove(it *lookupIterator) { + 
w.mu.Lock() + defer w.mu.Unlock() + delete(w.iters, it) + w.unblockLookup() +} + +func (w *lookupWalker) unblockLookup() { + w.cond.Signal() +} + +func anyIterNeedsNodes(iters map[*lookupIterator]struct{}) bool { + for it := range iters { + if it.needsNodes() { + return true + } + } + return false +} + +// lookupIterator is a sequence of discovered nodes. +type lookupIterator struct { + cur *enode.Node + walker *lookupWalker + filter filterFunc + mu sync.Mutex + cond *sync.Cond + buf []*enode.Node +} + +const lookupIteratorBuffer = 100 + +type filterFunc func(*enode.Node) bool + +func newLookupIterator(w *lookupWalker, filter filterFunc) *lookupIterator { + if filter == nil { + filter = func(*enode.Node) bool { return true } + } + it := &lookupIterator{ + walker: w, + filter: filter, + buf: make([]*enode.Node, 0, lookupIteratorBuffer), + } + it.cond = sync.NewCond(&it.mu) + return it +} + +func (it *lookupIterator) Next() bool { + it.cur = nil + + // Wait for the buffer to be filled. + it.mu.Lock() + defer it.mu.Unlock() + for it.buf != nil && len(it.buf) == 0 { + it.cond.Wait() + } + if it.buf == nil { + return false // closed + } + it.cur = it.buf[0] + copy(it.buf, it.buf[1:]) + it.buf = it.buf[:len(it.buf)-1] + fmt.Println("read node", len(it.buf)) + it.walker.unblockLookup() + return true +} + +func (it *lookupIterator) Node() *enode.Node { + return it.cur +} + +func (it *lookupIterator) Close() { + it.walker.remove(it) + it.close() +} + +func (it *lookupIterator) close() { + it.mu.Lock() + defer it.mu.Unlock() + + if it.buf != nil { + it.buf = nil + it.cond.Signal() + } +} + +// deliver places a node into the iterator buffer. +func (it *lookupIterator) deliver(n *enode.Node) bool { + it.mu.Lock() + defer it.mu.Unlock() + + if it.buf == nil || !it.filter(n) { + return true + } + if len(it.buf) == cap(it.buf) { + return false + } + it.buf = append(it.buf, n) + it.cond.Signal() + return true +} + +// needsNodes reports whether the iterator is low on nodes. 
+func (it *lookupIterator) needsNodes() bool { + it.mu.Lock() + defer it.mu.Unlock() + + return len(it.buf) < lookupIteratorBuffer/3 +} diff --git a/p2p/discover/common_test.go b/p2p/discover/common_test.go new file mode 100644 index 000000000000..e257d9938c18 --- /dev/null +++ b/p2p/discover/common_test.go @@ -0,0 +1,160 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "encoding/binary" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p/discutil" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +// This test checks basic operation of the lookup iterator. 
+func TestLookupIterator(t *testing.T) { + var ( + test = newLookupWalkerTest() + testNodes = makeTestNodes(lookupIteratorBuffer) + wg sync.WaitGroup + ) + + testIterator := func(it discutil.Iterator) { + defer wg.Done() + + // Check reading nodes: + nodes := discutil.ReadNodes(it, 20) + sortByID(nodes) + if err := checkNodesEqual(nodes, testNodes[:20]); err != nil { + t.Error(err) + } + nodes = discutil.ReadNodes(it, 20) + sortByID(nodes) + if err := checkNodesEqual(nodes, testNodes[20:40]); err != nil { + t.Error(err) + } + + // Check close: + it.Close() + if it.Next() { + t.Error("Next returned true after close") + } + if it.Node() != nil { + t.Error("iterator has non-nil node after close") + } + it.Close() // shouldn't crash + } + + for i := 0; i < 10; i++ { + wg.Add(1) + go testIterator(test.newIterator(nil)) + } + + test.serveOneLookup(testNodes[:10]) + test.serveOneLookup(testNodes[10:20]) + test.serveOneLookup(testNodes[20:40]) + wg.Wait() + + test.close() +} + +func TestLookupIteratorClose(t *testing.T) { + test := newLookupWalkerTest() + defer test.close() + it := test.newIterator(nil) + + go func() { + time.Sleep(200 * time.Millisecond) + it.Close() + }() + it.Next() +} + +// This test checks that the iterator kicks off a lookup when Next is called. +func TestLookupIteratorDrained(t *testing.T) { + var ( + test = newLookupWalkerTest() + it = test.newIterator(nil) + testNodes = makeTestNodes(2 * lookupIteratorBuffer) + ) + + test.serveOneLookup(testNodes[:lookupIteratorBuffer]) + nodes := discutil.ReadNodes(it, lookupIteratorBuffer) + sortByID(nodes) + if err := checkNodesEqual(nodes, testNodes[:lookupIteratorBuffer]); err != nil { + t.Fatal(err) + } + + // Here the iterator buffer is drained and no lookup is running. + + // Request more nodes. This needs to start another lookup. 
+ go test.serveOneLookup(testNodes[lookupIteratorBuffer:]) + nodes = discutil.ReadNodes(it, 10) + sortByID(nodes) + if err := checkNodesEqual(nodes, testNodes[lookupIteratorBuffer:lookupIteratorBuffer+10]); err != nil { + t.Fatal(err) + } +} + +func makeTestNodes(n int) []*enode.Node { + nodes := make([]*enode.Node, n) + for i := range nodes { + var nodeID enode.ID + binary.BigEndian.PutUint64(nodeID[:], uint64(i)) + nodes[i] = enode.SignNull(new(enr.Record), nodeID) + } + return nodes +} + +type lookupWalkerTest struct { + *lookupWalker + running int32 + nodes chan []*enode.Node +} + +func newLookupWalkerTest() *lookupWalkerTest { + wt := &lookupWalkerTest{nodes: make(chan []*enode.Node)} + wt.lookupWalker = newLookupWalker(wt.lookupFunc) + return wt +} + +// serveOneLookup allows one lookupFunc call to happen and makes it find +// the given nodes. +func (t *lookupWalkerTest) serveOneLookup(nodes []*enode.Node) { + t.nodes <- nodes + <-t.nodes +} + +func (t *lookupWalkerTest) lookupFunc(cancel <-chan struct{}, callback func(*enode.Node)) { + if atomic.AddInt32(&t.running, 1) != 1 { + panic("spawned more than one instance of lookupFunc") + } + defer atomic.AddInt32(&t.running, -1) + + select { + case nodes := <-t.nodes: + for _, n := range nodes { + callback(n) + } + t.nodes <- nil + case <-cancel: + return + } +} diff --git a/p2p/discover/table.go b/p2p/discover/table.go index e0a46792b44b..e5a5793e358f 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -147,35 +147,18 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { tab.mutex.Lock() defer tab.mutex.Unlock() - // Find all non-empty buckets and get a fresh slice of their entries. - var buckets [][]*node + var nodes []*enode.Node for _, b := range &tab.buckets { - if len(b.entries) > 0 { - buckets = append(buckets, b.entries) + for _, n := range b.entries { + nodes = append(nodes, unwrapNode(n)) } } - if len(buckets) == 0 { - return 0 - } - // Shuffle the buckets. 
- for i := len(buckets) - 1; i > 0; i-- { - j := tab.rand.Intn(len(buckets)) - buckets[i], buckets[j] = buckets[j], buckets[i] - } - // Move head of each bucket into buf, removing buckets that become empty. - var i, j int - for ; i < len(buf); i, j = i+1, (j+1)%len(buckets) { - b := buckets[j] - buf[i] = unwrapNode(b[0]) - buckets[j] = b[1:] - if len(b) == 1 { - buckets = append(buckets[:j], buckets[j+1:]...) - } - if len(buckets) == 0 { - break - } + // Shuffle. + for i := 0; i < len(nodes); i++ { + j := tab.rand.Intn(len(nodes)) + nodes[i], nodes[j] = nodes[j], nodes[i] } - return i + 1 + return copy(buf, nodes) } // getNode returns the node with the given ID or nil if it isn't in the table. diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 8e5fc7374b47..2adfd0dc0ec1 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -17,11 +17,14 @@ package discover import ( + "bytes" "crypto/ecdsa" "encoding/hex" + "errors" "fmt" "math/rand" "net" + "reflect" "sort" "sync" @@ -169,6 +172,28 @@ func hasDuplicates(slice []*node) bool { return false } +func checkNodesEqual(got, want []*enode.Node) error { + if reflect.DeepEqual(got, want) { + return nil + } + output := new(bytes.Buffer) + fmt.Fprintf(output, "got %d nodes:\n", len(got)) + for _, n := range got { + fmt.Fprintf(output, " %v %v\n", n.ID(), n) + } + fmt.Fprintf(output, "want %d:\n", len(want)) + for _, n := range want { + fmt.Fprintf(output, " %v %v\n", n.ID(), n) + } + return errors.New(output.String()) +} + +func sortByID(nodes []*enode.Node) { + sort.Slice(nodes, func(i, j int) bool { + return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes()) + }) +} + func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { return sort.SliceIsSorted(slice, func(i, j int) bool { return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0 diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index b2a5d85cf421..ecfc98388117 100644 
--- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discutil" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" @@ -202,6 +203,7 @@ type UDPv4 struct { localNode *enode.LocalNode db *enode.DB tab *Table + randomWalk *lookupWalker closeOnce sync.Once wg sync.WaitGroup @@ -270,6 +272,7 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { if t.log == nil { t.log = log.Root() } + tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log) if err != nil { return nil, err @@ -277,6 +280,8 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { t.tab = tab go tab.loop() + t.randomWalk = newLookupWalker(t.randomLookupWithCallback) + t.wg.Add(2) go t.loop() go t.readLoop(cfg.Unhandled) @@ -295,47 +300,57 @@ func (t *UDPv4) Close() { t.conn.Close() t.wg.Wait() t.tab.close() + t.randomWalk.close() }) } -// ReadRandomNodes reads random nodes from the local table. -func (t *UDPv4) ReadRandomNodes(buf []*enode.Node) int { - return t.tab.ReadRandomNodes(buf) +// RandomNodes is an iterator yielding nodes from a random walk of the DHT. +// +// All iterators share the same random walk to minimize network traffic. Discovered nodes +// are checked against the filter function and returned by the iterator only when the +// filter returns true. +func (t *UDPv4) RandomNodes(filter func(*enode.Node) bool) discutil.Iterator { + return t.randomWalk.newIterator(filter) } // LookupRandom finds random nodes in the network. -func (t *UDPv4) LookupRandom() []*enode.Node { +func (t *UDPv4) randomLookupWithCallback(cancel <-chan struct{}, callback func(*enode.Node)) { if t.tab.len() == 0 { // All nodes were dropped, refresh. The very first query will hit this // case and run the bootstrapping logic. 
<-t.tab.refresh() } - return t.lookupRandom() + var target encPubkey + crand.Read(target[:]) + t.lookup(target, cancel, callback) } +// LookupPubkey finds the closest nodes to the given public key. func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node { if t.tab.len() == 0 { // All nodes were dropped, refresh. The very first query will hit this // case and run the bootstrapping logic. <-t.tab.refresh() } - return unwrapNodes(t.lookup(encodePubkey(key))) + return unwrapNodes(t.lookup(encodePubkey(key), t.tab.closeReq, nil)) } +// for Table func (t *UDPv4) lookupRandom() []*enode.Node { var target encPubkey crand.Read(target[:]) - return unwrapNodes(t.lookup(target)) + return unwrapNodes(t.lookup(target, t.tab.closeReq, nil)) } +// for Table func (t *UDPv4) lookupSelf() []*enode.Node { - return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey))) + return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey), t.tab.closeReq, nil)) } // lookup performs a network search for nodes close to the given target. It approaches the // target by querying nodes that are closer to it on each iteration. The given target does // not need to be an actual node identifier. -func (t *UDPv4) lookup(targetKey encPubkey) []*node { +func (t *UDPv4) lookup(targetKey encPubkey, cancel <-chan struct{}, nodeCallback func(*enode.Node)) []*node { var ( target = enode.ID(crypto.Keccak256Hash(targetKey[:])) asked = make(map[enode.ID]bool) @@ -360,7 +375,7 @@ func (t *UDPv4) lookup(targetKey encPubkey) []*node { if !asked[n.ID()] { asked[n.ID()] = true pendingQueries++ - go t.lookupWorker(n, targetKey, reply) + go t.lookupWorker(n, targetKey, reply, nodeCallback) } } if pendingQueries == 0 { @@ -375,7 +390,7 @@ func (t *UDPv4) lookup(targetKey encPubkey) []*node { result.push(n, bucketSize) } } - case <-t.tab.closeReq: + case <-cancel: return nil // shutdown, no need to continue. 
} pendingQueries-- @@ -383,7 +398,7 @@ func (t *UDPv4) lookup(targetKey encPubkey) []*node { return result.entries } -func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node) { +func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node, callback func(*enode.Node)) { fails := t.db.FindFails(n.ID(), n.IP()) r, err := t.findnode(n.ID(), n.addr(), targetKey) if err == errClosed { @@ -407,7 +422,11 @@ func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node) // just remove those again during revalidation. for _, n := range r { t.tab.addSeenNode(n) + if callback != nil { + callback(unwrapNode(n)) + } } + reply <- r } diff --git a/p2p/discutil/iter.go b/p2p/discutil/iter.go new file mode 100644 index 000000000000..cebd2236fd13 --- /dev/null +++ b/p2p/discutil/iter.go @@ -0,0 +1,239 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package discutil provides node discovery utilities. +package discutil + +import ( + "sync" + "time" + + "github.com/ethereum/go-ethereum/p2p/enode" +) + +// Iterator represents a sequence of nodes. +// +// The Next method returns the next node in the sequence. The isLive return value reports +// whether the iterator is still open. 
Once closed, Next returns false and Node returns nil.
+//
+// Close may be called concurrently with Next and Node, and interrupts Next if it is blocked.
+type Iterator interface {
+	Next() bool        // moves to next node
+	Node() *enode.Node // returns current node
+	Close()            // ends the iterator
+}
+
+// ReadNodes reads at most n nodes from the given iterator. The return value contains no
+// duplicates and no nil values. To prevent looping indefinitely for small repeating node
+// sequences, this function calls Next at most n times.
+func ReadNodes(it Iterator, n int) []*enode.Node {
+	seen := make(map[enode.ID]*enode.Node, n)
+	for i := 0; i < n && it.Next(); i++ {
+		// Remove duplicates, keeping the node with higher seq.
+		node := it.Node()
+		prevNode, ok := seen[node.ID()]
+		if ok && prevNode.Seq() > node.Seq() {
+			continue
+		}
+		seen[node.ID()] = node
+	}
+	result := make([]*enode.Node, 0, len(seen))
+	for _, node := range seen {
+		result = append(result, node)
+	}
+	return result
+}
+
+// Filter wraps an iterator such that Next only returns nodes for which
+// the 'check' function returns true.
+func Filter(it Iterator, check func(*enode.Node) bool) Iterator {
+	return &filterIter{it, check}
+}
+
+type filterIter struct {
+	Iterator
+	check func(*enode.Node) bool
+}
+
+func (f *filterIter) Next() bool {
+	for f.Iterator.Next() {
+		if f.check(f.Node()) {
+			return true
+		}
+	}
+	return false
+}
+
+// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends
+// only when Close is called. Source iterators added via AddSource are removed from the mix
+// when they end.
+//
+// The distribution of nodes returned by Next is approximately fair, i.e. FairMix
+// attempts to draw from all sources equally often. However, if a certain source is slow
+// and doesn't return a node within the configured timeout, a node from any other source
+// will be returned.
+//
+// It's safe to call AddSource and Close concurrently with Next.
+type FairMix struct {
+	wg      sync.WaitGroup
+	fromAny chan *enode.Node
+	timeout time.Duration
+	cur     *enode.Node
+
+	mu      sync.Mutex
+	closed  chan struct{}
+	sources []*mixSource
+	last    int
+}
+
+type mixSource struct {
+	it   Iterator
+	next chan *enode.Node
+}
+
+// NewFairMix creates a mixer.
+//
+// The timeout specifies how long the mixer will wait for the next fairly-chosen source
+// before giving up and taking a node from any other source. A good way to set the timeout
+// is deciding how long you'd want to wait for a node on average. Passing a negative
+// timeout makes the mixer completely fair.
+func NewFairMix(timeout time.Duration) *FairMix {
+	m := &FairMix{
+		fromAny: make(chan *enode.Node),
+		closed:  make(chan struct{}),
+		timeout: timeout,
+	}
+	return m
+}
+
+// AddSource adds a source of nodes.
+func (m *FairMix) AddSource(it Iterator) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.closed == nil {
+		return
+	}
+	m.wg.Add(1)
+	source := &mixSource{it, make(chan *enode.Node)}
+	m.sources = append(m.sources, source)
+	go m.runSource(m.closed, source)
+}
+
+// Close shuts down the mixer and all current sources.
+// Calling this is required to release resources associated with the mixer.
+func (m *FairMix) Close() {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.closed == nil {
+		return
+	}
+	for _, s := range m.sources {
+		s.it.Close()
+	}
+	close(m.closed)
+	m.wg.Wait()
+	close(m.fromAny)
+	m.sources = nil
+	m.closed = nil
+}
+
+// Next returns a node from a random source.
+func (m *FairMix) Next() bool {
+	m.cur = nil
+
+	var timeout <-chan time.Time
+	if m.timeout >= 0 {
+		timer := time.NewTimer(m.timeout)
+		timeout = timer.C
+		defer timer.Stop()
+	}
+	for {
+		source := m.pickSource()
+		if source == nil {
+			return m.nextFromAny()
+		}
+		select {
+		case n, ok := <-source.next:
+			if ok {
+				m.cur = n
+				return true
+			}
+			// This source has ended.
+ m.deleteSource(source) + case <-timeout: + return m.nextFromAny() + } + } +} + +// Node returns the current node. +func (m *FairMix) Node() *enode.Node { + return m.cur +} + +// nextFromAny is used when there are no sources or when the 'fair' choice +// doesn't turn up a node quickly enough. +func (m *FairMix) nextFromAny() bool { + n, ok := <-m.fromAny + if ok { + m.cur = n + } + return ok +} + +// pickSource chooses the next source to read from, cycling through them in order. +func (m *FairMix) pickSource() *mixSource { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.sources) == 0 { + return nil + } + m.last = (m.last + 1) % len(m.sources) + return m.sources[m.last] +} + +// deleteSource deletes a source. +func (m *FairMix) deleteSource(s *mixSource) { + m.mu.Lock() + defer m.mu.Unlock() + + for i := range m.sources { + if m.sources[i] == s { + copy(m.sources[i:], m.sources[i+1:]) + m.sources[len(m.sources)-1] = nil + m.sources = m.sources[:len(m.sources)-1] + break + } + } +} + +// runSource reads a single source in a loop. +func (m *FairMix) runSource(closed chan struct{}, s *mixSource) { + defer m.wg.Done() + defer close(s.next) + for s.it.Next() { + n := s.it.Node() + select { + case s.next <- n: + case m.fromAny <- n: + case <-closed: + return + } + } +} diff --git a/p2p/discutil/iter_test.go b/p2p/discutil/iter_test.go new file mode 100644 index 000000000000..4d65d67a223f --- /dev/null +++ b/p2p/discutil/iter_test.go @@ -0,0 +1,275 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package discutil
+
+import (
+	"encoding/binary"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+)
+
+func TestReadNodes(t *testing.T) {
+	nodes := ReadNodes(new(genIter), 10)
+	checkNodes(t, nodes, 10)
+}
+
+// This test checks that ReadNodes terminates when reading N nodes from an iterator
+// which returns less than N nodes in an endless cycle.
+func TestReadNodesCycle(t *testing.T) {
+	iter := &callCountIter{
+		Iterator: cycleNodes(
+			testNode(0, 0),
+			testNode(1, 0),
+			testNode(2, 0),
+		),
+	}
+	nodes := ReadNodes(iter, 10)
+	checkNodes(t, nodes, 3)
+	if iter.count != 10 {
+		t.Fatalf("%d calls to Next, want %d", iter.count, 10)
+	}
+}
+
+func checkNodes(t *testing.T, nodes []*enode.Node, wantLen int) {
+	if len(nodes) != wantLen {
+		t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen)
+		return
+	}
+	seen := make(map[enode.ID]bool)
+	for i, e := range nodes {
+		if e == nil {
+			t.Errorf("nil node at index %d", i)
+			return
+		}
+		if seen[e.ID()] {
+			t.Errorf("slice has duplicate node %v", e.ID())
+			return
+		}
+		seen[e.ID()] = true
+	}
+}
+
+// This test checks fairness of FairMix in the happy case where all sources return nodes
+// within the timeout.
+func TestFairMix(t *testing.T) { + for i := 0; i < 500; i++ { + testMixerFairness(t) + } +} + +func testMixerFairness(t *testing.T) { + mix := NewFairMix(1 * time.Second) + mix.AddSource(&genIter{index: 1}) + mix.AddSource(&genIter{index: 2}) + mix.AddSource(&genIter{index: 3}) + defer mix.Close() + + nodes := ReadNodes(mix, 500) + checkNodes(t, nodes, 500) + + // Verify that the nodes slice contains an approximately equal number of nodes + // from each source. + d := idPrefixDistribution(nodes) + for _, count := range d { + if approxEqual(count, len(nodes)/3, 30) { + t.Fatalf("ID distribution is unfair: %v", d) + } + } +} + +// This test checks that FairMix falls back to an alternative source when +// the 'fair' choice doesn't return a node within the timeout. +func TestFairMixNextFromAll(t *testing.T) { + mix := NewFairMix(1 * time.Millisecond) + mix.AddSource(&genIter{index: 1}) + mix.AddSource(cycleNodes()) + defer mix.Close() + + nodes := ReadNodes(mix, 500) + checkNodes(t, nodes, 500) + + d := idPrefixDistribution(nodes) + if len(d) > 1 || d[1] != len(nodes) { + t.Fatalf("wrong ID distribution: %v", d) + } +} + +// This test ensures FairMix works for Next with no sources. +func TestFairMixEmpty(t *testing.T) { + var ( + mix = NewFairMix(1 * time.Second) + testN = testNode(1, 1) + ch = make(chan *enode.Node) + ) + defer mix.Close() + + go func() { + mix.Next() + ch <- mix.Node() + }() + + mix.AddSource(cycleNodes(testN)) + if n := <-ch; n != testN { + t.Errorf("got wrong node: %v", n) + } +} + +// This test checks closing a source while Next runs. 
+func TestFairMixRemoveSource(t *testing.T) { + mix := NewFairMix(1 * time.Second) + source := cycleNodes() + source.Close() + mix.AddSource(source) + + if mix.Next() { + t.Fatal("Next should've returned false") + } + if len(mix.sources) != 0 { + t.Fatalf("have %d sources, want zero", len(mix.sources)) + } +} + +func TestFairMixClose(t *testing.T) { + for i := 0; i < 20 && !t.Failed(); i++ { + testMixerClose(t) + } +} + +func testMixerClose(t *testing.T) { + mix := NewFairMix(-1) + mix.AddSource(cycleNodes()) + mix.AddSource(cycleNodes()) + + done := make(chan struct{}) + go func() { + defer close(done) + if mix.Next() { + t.Error("Next returned true") + } + }() + // This call is supposed to make it more likely that NextNode is + // actually executing by the time we call Close. + runtime.Gosched() + + mix.Close() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("Next didn't unblock on Close") + } + + mix.Close() // shouldn't crash +} + +func idPrefixDistribution(nodes []*enode.Node) map[uint32]int { + d := make(map[uint32]int) + for _, node := range nodes { + id := node.ID() + d[binary.BigEndian.Uint32(id[:4])]++ + } + return d +} + +func approxEqual(x, y, ε int) bool { + if y > x { + x, y = y, x + } + return x-y > ε +} + +// genIter creates fake nodes with numbered IDs based on 'index' and 'gen' +type genIter struct { + node *enode.Node + index, gen uint32 +} + +func (s *genIter) Next() bool { + index := atomic.LoadUint32(&s.index) + if index == ^uint32(0) { + s.node = nil + return false + } + s.node = testNode(uint64(index)<<32|uint64(s.gen), 0) + s.gen++ + return true +} + +func (s *genIter) Node() *enode.Node { + return s.node +} + +func (s *genIter) Close() { + s.index = ^uint32(0) +} + +func testNode(id, seq uint64) *enode.Node { + var nodeID enode.ID + binary.BigEndian.PutUint64(nodeID[:], id) + r := new(enr.Record) + r.SetSeq(seq) + return enode.SignNull(r, nodeID) +} + +// cycleNodes is an interator that cycles through the given 
slice.
+func cycleNodes(nodes ...*enode.Node) Iterator {
+	return &cycleIter{nodes: nodes}
+}
+
+type cycleIter struct {
+	cur   *enode.Node
+	mu    sync.Mutex
+	index int
+	nodes []*enode.Node
+}
+
+func (s *cycleIter) Next() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.nodes) == 0 {
+		return false
+	}
+	s.cur = s.nodes[s.index]
+	s.index = (s.index + 1) % len(s.nodes)
+	return true
+}
+
+func (s *cycleIter) Node() *enode.Node {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.cur
+}
+
+func (s *cycleIter) Close() {
+	s.mu.Lock()
+	s.nodes = nil
+	s.mu.Unlock()
+}
+
+// callCountIter counts calls to Next.
+type callCountIter struct {
+	Iterator
+	count int
+}
+
+func (it *callCountIter) Next() bool {
+	it.count++
+	return it.Iterator.Next()
+}
diff --git a/p2p/protocol.go b/p2p/protocol.go
index 9438ab8e47dd..30496709311c 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -19,6 +19,7 @@ package p2p
 import (
 	"fmt"
 
+	"github.com/ethereum/go-ethereum/p2p/discutil"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/enr"
 )
@@ -54,6 +55,11 @@ type Protocol struct {
 	// but returns nil, it is assumed that the protocol handshake is still running.
 	PeerInfo func(id enode.ID) interface{}
 
+	// DialCandidates, if non-nil, is a way to tell Server about protocol-specific nodes
+	// that should be dialed. The server continuously reads nodes from the iterator and
+	// attempts to create connections to them.
+	DialCandidates discutil.Iterator
+
 	// Attributes contains protocol specific information for the node record.
Attributes []enr.Entry } diff --git a/p2p/server.go b/p2p/server.go index b7340a5ea718..798d3162874a 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -35,6 +35,7 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/discutil" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" @@ -167,16 +168,20 @@ type Server struct { lock sync.Mutex // protects running running bool - nodedb *enode.DB - localnode *enode.LocalNode - ntab discoverTable listener net.Listener ourHandshake *protoHandshake - DiscV5 *discv5.Network loopWG sync.WaitGroup // loop, listenLoop peerFeed event.Feed log log.Logger + nodedb *enode.DB + localnode *enode.LocalNode + ntab *discover.UDPv4 + DiscV5 *discv5.Network + discmix *discutil.FairMix + + staticNodeResolver nodeResolver + // Channels into the run loop. quit chan struct{} addstatic chan *enode.Node @@ -465,7 +470,7 @@ func (srv *Server) Start() (err error) { } dynPeers := srv.maxDialedConns() - dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config) + dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config) srv.loopWG.Add(1) go srv.run(dialer) return nil @@ -517,6 +522,8 @@ func (srv *Server) setupLocalNode() error { } func (srv *Server) setupDiscovery() error { + srv.discmix = discutil.NewFairMix(fallbackInterval) + if srv.NoDiscovery && !srv.DiscoveryV5 { return nil } @@ -558,7 +565,10 @@ func (srv *Server) setupDiscovery() error { return err } srv.ntab = ntab + srv.discmix.AddSource(ntab.RandomNodes(nil)) + srv.staticNodeResolver = ntab } + // Discovery V5 if srv.DiscoveryV5 { var ntab *discv5.Network diff --git a/p2p/server_test.go b/p2p/server_test.go index e8bc627e1d30..9f01ffc16bf6 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -234,7 +234,6 @@ func TestServerTaskScheduling(t *testing.T) { 
localnode: enode.NewLocalNode(db, newkey()), nodedb: db, quit: make(chan struct{}), - ntab: fakeTable{}, running: true, log: log.New(), } @@ -282,7 +281,6 @@ func TestServerManyTasks(t *testing.T) { quit: make(chan struct{}), localnode: enode.NewLocalNode(db, newkey()), nodedb: db, - ntab: fakeTable{}, running: true, log: log.New(), }