diff --git a/requestmanager/requestmanager.go b/requestmanager/requestmanager.go index cbd33713..02497c7d 100644 --- a/requestmanager/requestmanager.go +++ b/requestmanager/requestmanager.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "github.com/hannahhoward/go-pubsub" + "golang.org/x/xerrors" "sync/atomic" blocks "github.com/ipfs/go-block-format" @@ -70,6 +72,7 @@ type RequestManager struct { peerHandler PeerHandler rc *responseCollector asyncLoader AsyncLoader + disconnectNotif *pubsub.PubSub // dont touch out side of run loop nextRequestID graphsync.RequestID inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus @@ -111,6 +114,7 @@ func New(ctx context.Context, ctx: ctx, cancel: cancel, asyncLoader: asyncLoader, + disconnectNotif: pubsub.New(disconnectDispatcher), rc: newResponseCollector(ctx), messages: make(chan requestManagerMessage, 16), inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus), @@ -128,6 +132,7 @@ func (rm *RequestManager) SetDelegate(peerHandler PeerHandler) { type inProgressRequest struct { requestID graphsync.RequestID + request gsmsg.GraphSyncRequest incoming chan graphsync.ResponseProgress incomingError chan error } @@ -166,6 +171,11 @@ func (rm *RequestManager) SendRequest(ctx context.Context, case receivedInProgressRequest = <-inProgressRequestChan: } + // If the connection to the peer is disconnected, fire an error + unsub := rm.listenForDisconnect(p, func(neterr error) { + rm.networkErrorListeners.NotifyNetworkErrorListeners(p, receivedInProgressRequest.request, neterr) + }) + return rm.rc.collectResponses(ctx, receivedInProgressRequest.incoming, receivedInProgressRequest.incomingError, @@ -173,7 +183,34 @@ func (rm *RequestManager) SendRequest(ctx context.Context, rm.cancelRequest(receivedInProgressRequest.requestID, receivedInProgressRequest.incoming, receivedInProgressRequest.incomingError) - }) + }, + // Once the request has completed, stop listening for disconnect events + unsub, + ) +} + +// Dispatch the Disconnect event to subscribers +func disconnectDispatcher(p pubsub.Event, subscriberFn pubsub.SubscriberFn) error { + listener := subscriberFn.(func(peer.ID)) + listener(p.(peer.ID)) + return nil +} + +// Listen for the Disconnect event for the given peer +func (rm *RequestManager) listenForDisconnect(p peer.ID, onDisconnect func(neterr error)) func() { + // Subscribe to Disconnect notifications + return rm.disconnectNotif.Subscribe(func(evtPeer peer.ID) { + // If the peer is the one we're interested in, call the listener + if evtPeer == p { + onDisconnect(xerrors.Errorf("disconnected from peer %s", p)) + } + }) +} + +// Disconnected is called when a peer disconnects +func (rm *RequestManager) Disconnected(p peer.ID) { + // Notify any listeners that a peer has disconnected + rm.disconnectNotif.Publish(p) } func (rm *RequestManager) emptyResponse() (chan graphsync.ResponseProgress, chan error) { @@ -311,17 +348,19 @@ type terminateRequestMessage struct { requestID graphsync.RequestID } -func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (chan graphsync.ResponseProgress, chan error) { +func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (gsmsg.GraphSyncRequest, chan graphsync.ResponseProgress, chan error) { request, hooksResult, err := rm.validateRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions) if err != nil { - return rm.singleErrorResponse(err) + rp, err := rm.singleErrorResponse(err) + return request, rp, err } doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs) var doNotSendCids *cid.Set if has { doNotSendCids, err = cidset.DecodeCidSet(doNotSendCidsData) if err != nil { - return rm.singleErrorResponse(err) + rp, err := rm.singleErrorResponse(err) + return request, rp, err } } else { doNotSendCids = cid.NewSet() @@ -355,14 +394,14 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re ResumeMessages: resumeMessages, PauseMessages: pauseMessages, }) - return incoming, incomingError + return request, incoming, incomingError } func (nrm *newRequestMessage) handle(rm *RequestManager) { var ipr inProgressRequest ipr.requestID = rm.nextRequestID rm.nextRequestID++ - ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm) + ipr.request, ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm) select { case nrm.inProgressRequestChan <- ipr: diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index 31b248aa..6c514a43 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -352,6 +352,42 @@ func TestRequestReturnsMissingBlocks(t *testing.T) { require.NotEqual(t, len(errs), 0, "did not send errors") } +func TestDisconnectNotification(t *testing.T) { + ctx := context.Background() + td := newTestData(ctx, t) + requestCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + peers := testutil.GeneratePeers(2) + + // Listen for network errors + networkErrors := make(chan peer.ID, 1) + td.networkErrorListeners.Register(func(p peer.ID, request graphsync.RequestData, err error) { + networkErrors <- p + }) + + // Send a request to the target peer + targetPeer := peers[0] + td.requestManager.SendRequest(requestCtx, targetPeer, td.blockChain.TipLink, td.blockChain.Selector()) + + // Disconnect a random peer, should not fire any events + randomPeer := peers[1] + td.requestManager.Disconnected(randomPeer) + select { + case <-networkErrors: + t.Fatal("should not fire network error when unrelated peer disconnects") + default: + } + + // Disconnect the target peer, should fire a network error + td.requestManager.Disconnected(targetPeer) + select { + case p:= <-networkErrors: + require.Equal(t, p, targetPeer) + default: + t.Fatal("should fire network error when peer disconnects") + } +} + func TestEncodingExtensions(t *testing.T) { ctx := context.Background() td := newTestData(ctx, t) diff --git a/requestmanager/responsecollector.go b/requestmanager/responsecollector.go index 766b9fed..062ca47c 100644 --- a/requestmanager/responsecollector.go +++ b/requestmanager/responsecollector.go @@ -18,7 +18,9 @@ func (rc *responseCollector) collectResponses( requestCtx context.Context, incomingResponses <-chan graphsync.ResponseProgress, incomingErrors <-chan error, - cancelRequest func()) (<-chan graphsync.ResponseProgress, <-chan error) { + cancelRequest func(), + onComplete func(), +) (<-chan graphsync.ResponseProgress, <-chan error) { returnedResponses := make(chan graphsync.ResponseProgress) returnedErrors := make(chan error) @@ -26,6 +28,7 @@ func (rc *responseCollector) collectResponses( go func() { var receivedResponses []graphsync.ResponseProgress defer close(returnedResponses) + defer onComplete() outgoingResponses := func() chan<- graphsync.ResponseProgress { if len(receivedResponses) == 0 { return nil diff --git a/requestmanager/responsecollector_test.go b/requestmanager/responsecollector_test.go index c543157e..51d1af7e 100644 --- a/requestmanager/responsecollector_test.go +++ b/requestmanager/responsecollector_test.go @@ -26,7 +26,7 @@ func TestBufferingResponseProgress(t *testing.T) { cancelRequest := func() {} outgoingResponses, outgoingErrors := rc.collectResponses( - requestCtx, incomingResponses, incomingErrors, cancelRequest) + requestCtx, incomingResponses, incomingErrors, cancelRequest, func(){}) blockStore := make(map[ipld.Link][]byte) loader, storer := testutil.NewTestStore(blockStore)