Skip to content

Commit

Permalink
Dont double count data sent (#185)
Browse files Browse the repository at this point in the history
* fix: dont double count data sent

* fix: stricter asserts in channels tests
  • Loading branch information
dirkmc authored Apr 5, 2021
1 parent d7ca900 commit 2ceebfb
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 56 deletions.
32 changes: 19 additions & 13 deletions channels/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,21 @@ func (c *Channels) CompleteCleanupOnRestart(chid datatransfer.ChannelID) error {
return c.send(chid, datatransfer.CompleteCleanupOnRestart)
}

func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error {
// Returns true if this is the first time the block has been sent
func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) {
return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, k, delta)
}

func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error {
// Returns true if this is the first time the block has been queued
func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) {
return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, k, delta)
}

func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error {
// Returns true if this is the first time the block has been received
func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) {
err := c.cidLists.AppendList(chid, k)
if err != nil {
return err
return false, err
}

return c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, k, delta)
Expand Down Expand Up @@ -366,31 +369,34 @@ func (c *Channels) removeSeenCIDCaches(chid datatransfer.ChannelID) error {
return nil
}

// onProgress fires an event indicating progress has been made in
// queuing / sending / receiving blocks.
// These events are fired only for new blocks (not for example if
// a block is resent)
func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64) error {
// fireProgressEvent fires
// - an event for queuing / sending / receiving blocks
// - a corresponding "progress" event if the block has not been seen before
// For example, if a block is being sent for the first time, the method will
// fire both DataSent AND DataSentProgress.
// If a block is resent, the method will fire DataSent but not DataSentProgress.
// Returns true if the block is new (both the event and a progress event were fired).
func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64) (bool, error) {
if err := c.checkChannelExists(chid, evt); err != nil {
return err
return false, err
}

// Check if the block has already been seen
sid := cidsets.SetID(chid.String() + "/" + datatransfer.Events[evt])
seen, err := c.seenCIDs.InsertSetCID(sid, k)
if err != nil {
return err
return false, err
}

// If the block has not been seen before, fire the progress event
if !seen {
if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil {
return err
return false, err
}
}

// Fire the regular event
return c.stateMachines.Send(chid, evt)
return !seen, c.stateMachines.Send(chid, evt)
}

func (c *Channels) send(chid datatransfer.ChannelID, code datatransfer.EventCode, args ...interface{}) error {
Expand Down
21 changes: 14 additions & 7 deletions channels/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,46 +145,53 @@ func TestChannels(t *testing.T) {
require.Equal(t, uint64(0), state.Sent())
require.Empty(t, state.ReceivedCids())

err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
isNew, err := channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
require.NoError(t, err)
_ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress)
require.True(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataReceived)
require.Equal(t, uint64(50), state.Received())
require.Equal(t, uint64(0), state.Sent())
require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids())

err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100)
isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100)
require.NoError(t, err)
_ = checkEvent(ctx, t, received, datatransfer.DataSentProgress)
require.True(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataSent)
require.Equal(t, uint64(50), state.Received())
require.Equal(t, uint64(100), state.Sent())
require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids())

// errors if channel does not exist
err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
require.True(t, xerrors.As(err, new(*channels.ErrNotFound)))
err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
require.False(t, isNew)
isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
require.True(t, xerrors.As(err, new(*channels.ErrNotFound)))
require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids())
require.False(t, isNew)

err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50)
isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50)
require.NoError(t, err)
_ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress)
require.True(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataReceived)
require.Equal(t, uint64(100), state.Received())
require.Equal(t, uint64(100), state.Sent())
require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids())

err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25)
isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25)
require.NoError(t, err)
require.False(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataSent)
require.Equal(t, uint64(100), state.Received())
require.Equal(t, uint64(100), state.Sent())
require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids())

err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
require.NoError(t, err)
require.False(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataReceived)
require.Equal(t, uint64(100), state.Received())
require.Equal(t, uint64(100), state.Sent())
Expand Down
110 changes: 74 additions & 36 deletions impl/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,63 +28,101 @@ func (m *manager) OnChannelOpened(chid datatransfer.ChannelID) error {
return nil
}

// OnDataReceived is called when the transport layer reports that it has
// received some data from the sender.
// It fires an event on the channel, updating the sum of received data and
// calls revalidators so they can pause / resume the channel or send a
// message over the transport.
func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64) error {
err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size)
isNew, err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size)
if err != nil {
return err
}

if chid.Initiator != m.peerID {
var result datatransfer.VoucherResult
var err error
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPushDataReceived(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
msg, err := m.processRevalidationResult(chid, result, err)
if msg != nil {
if err := m.dataTransferNetwork.SendMessage(context.TODO(), chid.Initiator, msg); err != nil {
return err
}
// If this block has already been received on the channel, take no further
// action (this can happen when the data-transfer channel is restarted)
if !isNew {
return nil
}

// If this node initiated the data transfer, there's nothing more to do
if chid.Initiator == m.peerID {
return nil
}

// Check each revalidator to see if they want to pause / resume, or send
// a message over the transport
var result datatransfer.VoucherResult
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPushDataReceived(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
msg, err := m.processRevalidationResult(chid, result, err)
if msg != nil {
if err := m.dataTransferNetwork.SendMessage(context.TODO(), chid.Initiator, msg); err != nil {
return err
}
return err
}
return err
}

return nil
}

// OnDataQueued is called when the transport layer reports that it has queued
// up some data to be sent to the requester.
// It fires an event on the channel, updating the sum of queued data and calls
// revalidators so they can pause / resume or send a message over the transport.
func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64) (datatransfer.Message, error) {
if err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size); err != nil {
// The transport layer reports that some data has been queued up to be sent
// to the requester, so fire a DataQueued event on the channels state
// machine.
isNew, err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size)
if err != nil {
return nil, err
}
if chid.Initiator != m.peerID {
var result datatransfer.VoucherResult
var err error
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPullDataSent(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
return m.processRevalidationResult(chid, result, err)

// If this block has already been queued on the channel, take no further
// action (this can happen when the data-transfer channel is restarted)
if !isNew {
return nil, nil
}

// If this node initiated the data transfer, there's nothing more to do
if chid.Initiator == m.peerID {
return nil, nil
}

// Check each revalidator to see if they want to pause / resume, or send
// a message over the transport.
// For example if the data-sender is waiting for the receiver to pay for
// data they may pause the data-transfer.
var result datatransfer.VoucherResult
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPullDataSent(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
return m.processRevalidationResult(chid, result, err)
}

return nil, nil
}

func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64) error {
return m.channels.DataSent(chid, link.(cidlink.Link).Cid, size)
_, err := m.channels.DataSent(chid, link.(cidlink.Link).Cid, size)
return err
}

func (m *manager) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
Expand Down
73 changes: 73 additions & 0 deletions impl/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
)

const loremFile = "lorem.txt"
const loremFileTransferBytes = 20439

// nil means use the default protocols
// tests data transfer for the following protocol combinations:
Expand Down Expand Up @@ -520,6 +521,66 @@ func (dc *disconnectCoordinator) onDisconnect() {
close(dc.disconnected)
}

type restartRevalidator struct {
*testutil.StubbedRevalidator
pullDataSent map[datatransfer.ChannelID][]uint64
pushDataRcvd map[datatransfer.ChannelID][]uint64
}

func newRestartRevalidator() *restartRevalidator {
return &restartRevalidator{
StubbedRevalidator: testutil.NewStubbedRevalidator(),
pullDataSent: make(map[datatransfer.ChannelID][]uint64),
pushDataRcvd: make(map[datatransfer.ChannelID][]uint64),
}
}

func (r *restartRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (bool, datatransfer.VoucherResult, error) {
chSent, ok := r.pullDataSent[chid]
if !ok {
chSent = []uint64{}
}
chSent = append(chSent, additionalBytesSent)
r.pullDataSent[chid] = chSent

return true, nil, nil
}

func (r *restartRevalidator) pullDataSum(chid datatransfer.ChannelID) uint64 {
pullDataSent, ok := r.pullDataSent[chid]
var total uint64
if !ok {
return total
}
for _, sent := range pullDataSent {
total += sent
}
return total
}

func (r *restartRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (bool, datatransfer.VoucherResult, error) {
chRcvd, ok := r.pushDataRcvd[chid]
if !ok {
chRcvd = []uint64{}
}
chRcvd = append(chRcvd, additionalBytesReceived)
r.pushDataRcvd[chid] = chRcvd

return true, nil, nil
}

func (r *restartRevalidator) pushDataSum(chid datatransfer.ChannelID) uint64 {
pushDataRcvd, ok := r.pushDataRcvd[chid]
var total uint64
if !ok {
return total
}
for _, rcvd := range pushDataRcvd {
total += rcvd
}
return total
}

// TestAutoRestart tests that if the connection for a push or pull request
// goes down, it will automatically restart (given the right config options)
func TestAutoRestart(t *testing.T) {
Expand Down Expand Up @@ -714,6 +775,10 @@ func TestAutoRestart(t *testing.T) {
require.NoError(t, initiator.RegisterVoucherType(&testutil.FakeDTType{}, sv))
require.NoError(t, responder.RegisterVoucherType(&testutil.FakeDTType{}, sv))

// Register a revalidator that records calls to OnPullDataSent and OnPushDataReceived
srv := newRestartRevalidator()
require.NoError(t, responder.RegisterRevalidator(testutil.NewFakeDTType(), srv))

// If the test case needs to subscribe to response events, provide
// the test case with the responder
if tc.registerResponder != nil {
Expand Down Expand Up @@ -795,6 +860,14 @@ func TestAutoRestart(t *testing.T) {
}
})()

// Verify that the total amount of data sent / received that was
// reported to the revalidator is correct
if tc.isPush {
require.EqualValues(t, loremFileTransferBytes, srv.pushDataSum(chid))
} else {
require.EqualValues(t, loremFileTransferBytes, srv.pullDataSum(chid))
}

// Verify that the file was transferred to the destination node
testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes)
})
Expand Down

0 comments on commit 2ceebfb

Please sign in to comment.