diff --git a/mempool/cat/cache.go b/mempool/cat/cache.go index 75b4db452b..95f9393c8f 100644 --- a/mempool/cat/cache.go +++ b/mempool/cat/cache.go @@ -130,19 +130,6 @@ func (s *SeenTxSet) Add(txKey types.TxKey, peer uint16) { } } -func (s *SeenTxSet) Pop(txKey types.TxKey) uint16 { - s.mtx.Lock() - defer s.mtx.Unlock() - seenSet, exists := s.set[txKey] - if exists { - for peer := range seenSet.peers { - delete(seenSet.peers, peer) - return peer - } - } - return 0 -} - func (s *SeenTxSet) RemoveKey(txKey types.TxKey) { s.mtx.Lock() defer s.mtx.Unlock() @@ -162,6 +149,17 @@ func (s *SeenTxSet) Remove(txKey types.TxKey, peer uint16) { } } +func (s *SeenTxSet) RemovePeer(peer uint16) { + s.mtx.Lock() + defer s.mtx.Unlock() + for key, seenSet := range s.set { + delete(seenSet.peers, peer) + if len(seenSet.peers) == 0 { + delete(s.set, key) + } + } +} + func (s *SeenTxSet) Prune(limit time.Time) { s.mtx.Lock() defer s.mtx.Unlock() diff --git a/mempool/cat/cache_test.go b/mempool/cat/cache_test.go index 5c7f6453bf..539670ff0a 100644 --- a/mempool/cat/cache_test.go +++ b/mempool/cat/cache_test.go @@ -22,7 +22,7 @@ func TestSeenTxSet(t *testing.T) { ) seenSet := NewSeenTxSet() - require.Zero(t, seenSet.Pop(tx1Key)) + require.Nil(t, seenSet.Get(tx1Key)) seenSet.Add(tx1Key, peer1) seenSet.Add(tx1Key, peer1) @@ -36,8 +36,8 @@ func TestSeenTxSet(t *testing.T) { require.Equal(t, 3, seenSet.Len()) seenSet.RemoveKey(tx2Key) require.Equal(t, 2, seenSet.Len()) - require.Zero(t, seenSet.Pop(tx2Key)) - require.Equal(t, peer1, seenSet.Pop(tx3Key)) + require.Nil(t, seenSet.Get(tx2Key)) + require.True(t, seenSet.Has(tx3Key, peer1)) } func TestLRUTxCacheRemove(t *testing.T) { diff --git a/mempool/cat/reactor.go b/mempool/cat/reactor.go index 54e0bbe999..abfdfd56d9 100644 --- a/mempool/cat/reactor.go +++ b/mempool/cat/reactor.go @@ -189,6 +189,9 @@ func (memR *Reactor) InitPeer(peer p2p.Peer) p2p.Peer { // peer it will find a new peer to rerequest the same transactions. func (memR *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) { peerID := memR.ids.Reclaim(peer.ID()) + // clear all memory of seen txs by that peer + memR.mempool.seenByPeersSet.RemovePeer(peerID) + // remove and rerequest all pending outbound requests to that peer since we know // we won't receive any responses from them. outboundRequests := memR.requests.ClearAllRequestsFrom(peerID) diff --git a/mempool/cat/reactor_test.go b/mempool/cat/reactor_test.go index 15d67bfad5..2a9b8bd655 100644 --- a/mempool/cat/reactor_test.go +++ b/mempool/cat/reactor_test.go @@ -157,6 +157,50 @@ func TestReactorBroadcastsSeenTxAfterReceivingTx(t *testing.T) { peers[1].AssertExpectations(t) } +func TestRemovePeerRequestFromOtherPeer(t *testing.T) { + reactor, _ := setupReactor(t) + + tx := newDefaultTx("hello") + key := tx.Key() + peers := genPeers(2) + reactor.InitPeer(peers[0]) + reactor.InitPeer(peers[1]) + + seenMsg := &protomem.SeenTx{TxKey: key[:]} + + wantEnv := p2p.Envelope{ + Message: &protomem.Message{ + Sum: &protomem.Message_WantTx{WantTx: &protomem.WantTx{TxKey: key[:]}}, + }, + ChannelID: MempoolStateChannel, + } + peers[0].On("SendEnvelope", wantEnv).Return(true) + peers[1].On("SendEnvelope", wantEnv).Return(true) + + reactor.ReceiveEnvelope(p2p.Envelope{ + Src: peers[0], + Message: seenMsg, + ChannelID: MempoolStateChannel, + }) + time.Sleep(100 * time.Millisecond) + reactor.ReceiveEnvelope(p2p.Envelope{ + Src: peers[1], + Message: seenMsg, + ChannelID: MempoolStateChannel, + }) + + reactor.RemovePeer(peers[0], "test") + + peers[0].AssertExpectations(t) + peers[1].AssertExpectations(t) + + require.True(t, reactor.mempool.seenByPeersSet.Has(key, 2)) + // we should have automatically sent another request out for peer 2 + require.EqualValues(t, 2, reactor.requests.ForTx(key)) + require.True(t, reactor.requests.Has(2, key)) + require.False(t, reactor.mempool.seenByPeersSet.Has(key, 1)) +} + func TestMempoolVectors(t *testing.T) { testCases := []struct { testName string diff --git a/mempool/cat/requests.go b/mempool/cat/requests.go index 5fdb344a87..8d0b78778f 100644 --- a/mempool/cat/requests.go +++ b/mempool/cat/requests.go @@ -113,8 +113,9 @@ func (r *requestScheduler) ClearAllRequestsFrom(peer uint16) requestSet { if !ok { return requestSet{} } - for _, timer := range requests { + for tx, timer := range requests { timer.Stop() + delete(r.requestsByTx, tx) } delete(r.requestsByPeer, peer) return requests