Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/events/complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,11 @@ func (u *UploadCompleter) CheckUploads(ctx context.Context) error {

log.Debugf("upload has %d parts", len(parts))

if err := u.cfg.Uploader.CompleteUpload(ctx, upload, parts); err != nil {
return trace.Wrap(err, "completing upload")
if err := u.cfg.Uploader.CompleteUpload(ctx, upload, parts); trace.IsNotFound(err) {
log.WithError(err).Debug("Upload not found, moving on to next upload")
continue
} else if err != nil {
return trace.Wrap(err)
}
log.Debug("Completed upload")
completed++
Expand Down
72 changes: 72 additions & 0 deletions lib/events/complete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package events_test

import (
"context"
"fmt"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -149,6 +150,77 @@ func TestUploadCompleterEmitsSessionEnd(t *testing.T) {
}
}

func TestCheckUploadsContinuesOnError(t *testing.T) {
clock := clockwork.NewFakeClock()
expires := clock.Now().Add(time.Hour * 1)

sessionTrackers := []types.SessionTracker{
&types.SessionTrackerV1{
Spec: types.SessionTrackerSpecV1{
SessionID: string(session.NewID()),
},
ResourceHeader: types.ResourceHeader{
Metadata: types.Metadata{
Expires: &expires,
},
},
},
&types.SessionTrackerV1{
Spec: types.SessionTrackerSpecV1{
SessionID: string(session.NewID()),
},
ResourceHeader: types.ResourceHeader{
Metadata: types.Metadata{
Expires: &expires,
},
},
},
}

sessionTrackerService := &mockSessionTrackerService{
clock: clock,
trackers: sessionTrackers,
}

var completedUploads []session.ID
uploader := &eventstest.MockUploader{
MockCompleteUpload: func(ctx context.Context, upload events.StreamUpload, parts []events.StreamPart) error {
// simulate a not found error on the first complete upload
if upload.SessionID == session.ID(sessionTrackers[0].GetSessionID()) {
return trace.NotFound("no such upload %v", sessionTrackers[0].GetSessionID())
}

completedUploads = append(completedUploads, upload.SessionID)
return nil
},
MockListUploads: func(ctx context.Context) ([]events.StreamUpload, error) {
var result []events.StreamUpload
for i, sess := range sessionTrackers {
result = append(result, events.StreamUpload{
ID: fmt.Sprintf("upload-%v", i),
SessionID: session.ID(sess.GetSessionID()),
Initiated: clock.Now(),
})
}
return result, nil
},
}

uc, err := events.NewUploadCompleter(events.UploadCompleterConfig{
Uploader: uploader,
AuditLog: &eventstest.MockAuditLog{},
SessionTracker: sessionTrackerService,
Clock: clock,
ClusterName: "teleport-cluster",
})
require.NoError(t, err)

// verify that the 2nd upload completed even though the first one failed
clock.Advance(1 * time.Hour)
uc.CheckUploads(context.Background())
require.ElementsMatch(t, completedUploads, []session.ID{session.ID(sessionTrackers[1].GetSessionID())})
}

type mockSessionTrackerService struct {
clock clockwork.Clock
trackers []types.SessionTracker
Expand Down
53 changes: 53 additions & 0 deletions lib/events/eventstest/uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,56 @@ func (m *MemoryUploader) GetUploadMetadata(sid session.ID) events.UploadMetadata
func (m *MemoryUploader) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
return nil
}

// MockUploader is a limited implementation of [events.MultipartUploader] that
// allows injecting errors for testing purposes. [MemoryUploader] is a more
// complete implementation and should be preferred for testing the happy path.
type MockUploader struct {
events.MultipartUploader

CreateUploadError error
ReserveUploadPartError error
ListPartsError error

MockListUploads func(ctx context.Context) ([]events.StreamUpload, error)
MockCompleteUpload func(ctx context.Context, upload events.StreamUpload, parts []events.StreamPart) error
}

func (m *MockUploader) CreateUpload(ctx context.Context, sessionID session.ID) (*events.StreamUpload, error) {
if m.CreateUploadError != nil {
return nil, m.CreateUploadError
}

return &events.StreamUpload{
ID: uuid.New().String(),
SessionID: sessionID,
}, nil
}

func (m *MockUploader) ReserveUploadPart(_ context.Context, _ events.StreamUpload, _ int64) error {
return m.ReserveUploadPartError
}

func (m *MockUploader) ListParts(_ context.Context, _ events.StreamUpload) ([]events.StreamPart, error) {
if m.ListPartsError != nil {
return nil, m.ListPartsError
}

return []events.StreamPart{}, nil
}

func (m *MockUploader) ListUploads(ctx context.Context) ([]events.StreamUpload, error) {
if m.MockListUploads != nil {
return m.MockListUploads(ctx)
}

return nil, nil
}

func (m *MockUploader) CompleteUpload(ctx context.Context, upload events.StreamUpload, parts []events.StreamPart) error {
if m.MockCompleteUpload != nil {
return m.MockCompleteUpload(ctx, upload, parts)
}

return nil
}
44 changes: 7 additions & 37 deletions lib/events/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestNewSliceErrors(t *testing.T) {
ctx := context.Background()
expectedErr := errors.New("test upload error")
streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{
Uploader: &mockUploader{reserveUploadPartError: expectedErr},
Uploader: &eventstest.MockUploader{ReserveUploadPartError: expectedErr},
})
require.NoError(t, err)

Expand All @@ -95,16 +95,16 @@ func TestNewStreamErrors(t *testing.T) {
t.Run("CreateAuditStream", func(t *testing.T) {
for _, tt := range []struct {
desc string
uploader *mockUploader
uploader *eventstest.MockUploader
expectedErr error
}{
{
desc: "CreateUploadError",
uploader: &mockUploader{createUploadError: expectedErr},
uploader: &eventstest.MockUploader{CreateUploadError: expectedErr},
},
{
desc: "ReserveUploadPartError",
uploader: &mockUploader{reserveUploadPartError: expectedErr},
uploader: &eventstest.MockUploader{ReserveUploadPartError: expectedErr},
},
} {
t.Run(tt.desc, func(t *testing.T) {
Expand All @@ -126,16 +126,16 @@ func TestNewStreamErrors(t *testing.T) {
t.Run("ResumeAuditStream", func(t *testing.T) {
for _, tt := range []struct {
desc string
uploader *mockUploader
uploader *eventstest.MockUploader
expectedErr error
}{
{
desc: "ListPartsError",
uploader: &mockUploader{listPartsError: expectedErr},
uploader: &eventstest.MockUploader{ListPartsError: expectedErr},
},
{
desc: "ReserveUploadPartError",
uploader: &mockUploader{reserveUploadPartError: expectedErr},
uploader: &eventstest.MockUploader{ReserveUploadPartError: expectedErr},
},
} {
t.Run(tt.desc, func(t *testing.T) {
Expand Down Expand Up @@ -194,36 +194,6 @@ func TestProtoStreamLargeEvent(t *testing.T) {
require.NoError(t, stream.Complete(ctx))
}

type mockUploader struct {
events.MultipartUploader
createUploadError error
reserveUploadPartError error
listPartsError error
}

func (m *mockUploader) CreateUpload(ctx context.Context, sessionID session.ID) (*events.StreamUpload, error) {
if m.createUploadError != nil {
return nil, m.createUploadError
}

return &events.StreamUpload{
ID: uuid.New().String(),
SessionID: sessionID,
}, nil
}

func (m *mockUploader) ReserveUploadPart(_ context.Context, _ events.StreamUpload, _ int64) error {
return m.reserveUploadPartError
}

func (m *mockUploader) ListParts(_ context.Context, _ events.StreamUpload) ([]events.StreamPart, error) {
if m.listPartsError != nil {
return nil, m.listPartsError
}

return []events.StreamPart{}, nil
}

func makeQueryEvent(id string, query string) *apievents.DatabaseSessionQuery {
return &apievents.DatabaseSessionQuery{
Metadata: apievents.Metadata{
Expand Down