diff --git a/internal/decision/engine.go b/internal/decision/engine.go index 81ef9b9e..49063bd5 100644 --- a/internal/decision/engine.go +++ b/internal/decision/engine.go @@ -334,8 +334,8 @@ func (e *Engine) onPeerRemoved(p peer.ID) { e.peerTagger.UntagPeer(p, e.tagQueued) } -// WantlistForPeer returns the currently understood want list for a given peer -func (e *Engine) WantlistForPeer(p peer.ID) (out []wl.Entry) { +// WantlistForPeer returns the list of keys that the given peer has asked for +func (e *Engine) WantlistForPeer(p peer.ID) []wl.Entry { partner := e.findOrCreate(p) partner.lk.Lock() @@ -343,7 +343,8 @@ func (e *Engine) WantlistForPeer(p peer.ID) (out []wl.Entry) { partner.lk.Unlock() wl.SortEntries(entries) - return + + return entries } // LedgerForPeer returns aggregated data about blocks swapped and communication diff --git a/internal/decision/engine_test.go b/internal/decision/engine_test.go index bdfa9362..cf000d96 100644 --- a/internal/decision/engine_test.go +++ b/internal/decision/engine_test.go @@ -981,6 +981,40 @@ func TestSendDontHave(t *testing.T) { } } +func TestWantlistForPeer(t *testing.T) { + bs := blockstore.NewBlockstore(dssync.MutexWrap(ds.NewMapDatastore())) + partner := libp2ptest.RandPeerIDFatal(t) + otherPeer := libp2ptest.RandPeerIDFatal(t) + + e := newEngine(context.Background(), bs, &fakePeerTagger{}, "localhost", 0, shortTerm, nil) + e.StartWorkers(context.Background(), process.WithTeardown(func() error { return nil })) + + blks := testutil.GenerateBlocksOfSize(4, 8*1024) + msg := message.New(false) + msg.AddEntry(blks[0].Cid(), 2, pb.Message_Wantlist_Have, false) + msg.AddEntry(blks[1].Cid(), 3, pb.Message_Wantlist_Have, false) + e.MessageReceived(context.Background(), partner, msg) + + msg2 := message.New(false) + msg2.AddEntry(blks[2].Cid(), 1, pb.Message_Wantlist_Block, false) + msg2.AddEntry(blks[3].Cid(), 4, pb.Message_Wantlist_Block, false) + e.MessageReceived(context.Background(), partner, msg2) + + entries := e.WantlistForPeer(otherPeer) + if len(entries) != 0 { + t.Fatal("expected wantlist to contain no wants for other peer") + } + + entries = e.WantlistForPeer(partner) + if len(entries) != 4 { + t.Fatal("expected wantlist to contain all wants from parter") + } + if entries[0].Priority != 4 || entries[1].Priority != 3 || entries[2].Priority != 2 || entries[3].Priority != 1 { + t.Fatal("expected wantlist to be sorted") + } + +} + func TestTaggingPeers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel()