Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

shutdown notifications engine when closing a bitswap session #4658

Merged
merged 8 commits into from
Feb 13, 2018
8 changes: 4 additions & 4 deletions exchange/bitswap/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type Exportable interface {

type impl struct {
full bool
wantlist map[string]Entry
wantlist map[string]*Entry
blocks map[string]blocks.Block
}

Expand All @@ -61,7 +61,7 @@ func New(full bool) BitSwapMessage {
func newMsg(full bool) *impl {
return &impl{
blocks: make(map[string]blocks.Block),
wantlist: make(map[string]Entry),
wantlist: make(map[string]*Entry),
full: full,
}
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func (m *impl) Empty() bool {
func (m *impl) Wantlist() []Entry {
out := make([]Entry, 0, len(m.wantlist))
for _, e := range m.wantlist {
out = append(out, e)
out = append(out, *e)
}
return out
}
Expand Down Expand Up @@ -151,7 +151,7 @@ func (m *impl) addEntry(c *cid.Cid, priority int, cancel bool) {
e.Priority = priority
e.Cancel = cancel
} else {
m.wantlist[k] = Entry{
m.wantlist[k] = &Entry{
Entry: &wantlist.Entry{
Cid: c,
Priority: priority,
Expand Down
56 changes: 53 additions & 3 deletions exchange/bitswap/notifications/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package notifications

import (
"context"
"sync"

blocks "gx/ipfs/Qmej7nf81hi2x2tvjRBF3mcp74sQyuDH4VMYDGd1YtXjb2/go-block-format"

Expand All @@ -18,18 +19,43 @@ type PubSub interface {
}

func New() PubSub {
return &impl{*pubsub.New(bufferSize)}
return &impl{
wrapped: *pubsub.New(bufferSize),
cancel: make(chan struct{}),
}
}

type impl struct {
wrapped pubsub.PubSub

// These two fields make up a shutdown "lock".
// We need them as calling, e.g., `Unsubscribe` after calling `Shutdown`
// blocks forever and fixing this in pubsub would be rather invasive.
cancel chan struct{}
wg sync.WaitGroup
}

func (ps *impl) Publish(block blocks.Block) {
ps.wg.Add(1)
defer ps.wg.Done()

select {
case <-ps.cancel:
// Already shutdown, bail.
return
default:
}

ps.wrapped.Pub(block, block.Cid().KeyString())
}

// Not safe to call more than once.
func (ps *impl) Shutdown() {
// Interrupt in-progress subscriptions.
close(ps.cancel)
// Wait for them to finish.
ps.wg.Wait()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will wait for all active wants to be cancelled, which happens if the caller closes the session, right? I would like to see a test around this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it'll just wait for all unsubscribes to finish (which should happen immediately after we close the ps.cancel channel).

However, I have added a test to ensure that shutting down the PubSub while a subscription is active works (and doesn't block as it did before). Is that what you're looking for?

// shutdown the pubsub.
ps.wrapped.Shutdown()
}

Expand All @@ -44,12 +70,34 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B
close(blocksCh)
return blocksCh
}

// prevent shutdown
ps.wg.Add(1)

// check if shutdown *after* preventing shutdowns.
select {
case <-ps.cancel:
// abort, allow shutdown to continue.
ps.wg.Done()
close(blocksCh)
return blocksCh
default:
}

ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...)
go func() {
defer close(blocksCh)
defer ps.wrapped.Unsub(valuesCh) // with a len(keys) buffer, this is an optimization
defer func() {
ps.wrapped.Unsub(valuesCh)
close(blocksCh)

// Unblock shutdown.
ps.wg.Done()
}()

for {
select {
case <-ps.cancel:
return
case <-ctx.Done():
return
case val, ok := <-valuesCh:
Expand All @@ -61,6 +109,8 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B
return
}
select {
case <-ps.cancel:
return
case <-ctx.Done():
return
case blocksCh <- block: // continue
Expand Down
19 changes: 19 additions & 0 deletions exchange/bitswap/notifications/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,25 @@ func TestDuplicateSubscribe(t *testing.T) {
assertBlocksEqual(t, e1, r2)
}

func TestShutdownBeforeUnsubscribe(t *testing.T) {
e1 := blocks.NewBlock([]byte("1"))

n := New()
ctx, cancel := context.WithCancel(context.Background())
ch := n.Subscribe(ctx, e1.Cid()) // no keys provided
n.Shutdown()
cancel()

select {
case _, ok := <-ch:
if ok {
t.Fatal("channel should have been closed")
}
default:
t.Fatal("channel should have been closed")
}
}

func TestSubscribeIsANoopWhenCalledWithNoKeys(t *testing.T) {
n := New()
defer n.Shutdown()
Expand Down
12 changes: 11 additions & 1 deletion exchange/bitswap/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ func (bs *Bitswap) NewSession(ctx context.Context) *Session {
}

func (bs *Bitswap) removeSession(s *Session) {
s.notif.Shutdown()

live := make([]*cid.Cid, 0, len(s.liveWants))
for c := range s.liveWants {
cs, _ := cid.Cast([]byte(c))
live = append(live, cs)
}
bs.CancelWants(live, s.id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if another session wants a cid which is cancelled here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wants are tracked per session (note that we pass in the session ID)


bs.sessLk.Lock()
defer bs.sessLk.Unlock()
for i := 0; i < len(bs.sessions); i++ {
Expand Down Expand Up @@ -270,8 +279,9 @@ func (s *Session) receiveBlock(ctx context.Context, blk blocks.Block) {
}

func (s *Session) wantBlocks(ctx context.Context, ks []*cid.Cid) {
now := time.Now()
for _, c := range ks {
s.liveWants[c.KeyString()] = time.Now()
s.liveWants[c.KeyString()] = now
}
s.bs.wm.WantBlocks(ctx, ks, s.activePeersArr, s.id)
}
Expand Down
33 changes: 33 additions & 0 deletions exchange/bitswap/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,36 @@ func TestMultipleSessions(t *testing.T) {
}
_ = blkch
}

func TestWantlistClearsOnCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vnet := getVirtualNetwork()
sesgen := NewTestSessionGenerator(vnet)
defer sesgen.Close()
bgen := blocksutil.NewBlockGenerator()

blks := bgen.Blocks(10)
var cids []*cid.Cid
for _, blk := range blks {
cids = append(cids, blk.Cid())
}

inst := sesgen.Instances(1)

a := inst[0]

ctx1, cancel1 := context.WithCancel(ctx)
ses := a.Exchange.NewSession(ctx1)

_, err := ses.GetBlocks(ctx, cids)
if err != nil {
t.Fatal(err)
}
cancel1()

if len(a.Exchange.GetWantlist()) > 0 {
t.Fatal("expected empty wantlist")
}
}
61 changes: 55 additions & 6 deletions exchange/bitswap/testnet/virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"time"

bsmsg "github.com/ipfs/go-ipfs/exchange/bitswap/message"
bsnet "github.com/ipfs/go-ipfs/exchange/bitswap/network"
Expand All @@ -22,7 +23,7 @@ var log = logging.Logger("bstestnet")

func VirtualNetwork(rs mockrouting.Server, d delay.D) Network {
return &network{
clients: make(map[peer.ID]bsnet.Receiver),
clients: make(map[peer.ID]*receiverQueue),
delay: d,
routingserver: rs,
conns: make(map[string]struct{}),
Expand All @@ -31,12 +32,28 @@ func VirtualNetwork(rs mockrouting.Server, d delay.D) Network {

type network struct {
mu sync.Mutex
clients map[peer.ID]bsnet.Receiver
clients map[peer.ID]*receiverQueue
routingserver mockrouting.Server
delay delay.D
conns map[string]struct{}
}

type message struct {
from peer.ID
msg bsmsg.BitSwapMessage
shouldSend time.Time
}

// receiverQueue queues up a set of messages to be sent, and sends them *in
// order* with their delays respected as much as sending them in order allows
// for
type receiverQueue struct {
receiver bsnet.Receiver
queue []*message
active bool
lk sync.Mutex
}

func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork {
n.mu.Lock()
defer n.mu.Unlock()
Expand All @@ -46,7 +63,7 @@ func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork {
network: n,
routing: n.routingserver.Client(p),
}
n.clients[p.ID()] = client
n.clients[p.ID()] = &receiverQueue{receiver: client}
return client
}

Expand All @@ -64,7 +81,7 @@ func (n *network) SendMessage(
ctx context.Context,
from peer.ID,
to peer.ID,
message bsmsg.BitSwapMessage) error {
mes bsmsg.BitSwapMessage) error {

n.mu.Lock()
defer n.mu.Unlock()
Expand All @@ -77,7 +94,12 @@ func (n *network) SendMessage(
// nb: terminate the context since the context wouldn't actually be passed
// over the network in a real scenario

go n.deliver(receiver, from, message)
msg := &message{
from: from,
msg: mes,
shouldSend: time.Now().Add(n.delay.Get()),
}
receiver.enqueue(msg)

return nil
}
Expand Down Expand Up @@ -191,11 +213,38 @@ func (nc *networkClient) ConnectTo(_ context.Context, p peer.ID) error {

// TODO: add handling for disconnects

otherClient.PeerConnected(nc.local)
otherClient.receiver.PeerConnected(nc.local)
nc.Receiver.PeerConnected(p)
return nil
}

func (rq *receiverQueue) enqueue(m *message) {
rq.lk.Lock()
defer rq.lk.Unlock()
rq.queue = append(rq.queue, m)
if !rq.active {
rq.active = true
go rq.process()
}
}

func (rq *receiverQueue) process() {
for {
rq.lk.Lock()
if len(rq.queue) == 0 {
rq.active = false
rq.lk.Unlock()
return
}
m := rq.queue[0]
rq.queue = rq.queue[1:]
rq.lk.Unlock()

time.Sleep(time.Until(m.shouldSend))
rq.receiver.ReceiveMessage(context.TODO(), m.from, m.msg)
}
}

func tagForPeers(a, b peer.ID) string {
if a < b {
return string(a + b)
Expand Down