diff --git a/lib/events/complete.go b/lib/events/complete.go index 8e8f81469cb84..9962368c3ad64 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -201,8 +201,11 @@ func (u *UploadCompleter) CheckUploads(ctx context.Context) error { log.Debugf("upload has %d parts", len(parts)) - if err := u.cfg.Uploader.CompleteUpload(ctx, upload, parts); err != nil { - return trace.Wrap(err, "completing upload") + if err := u.cfg.Uploader.CompleteUpload(ctx, upload, parts); trace.IsNotFound(err) { + log.WithError(err).Debug("Upload not found, moving on to next upload") + continue + } else if err != nil { + return trace.Wrap(err) } log.Debug("Completed upload") completed++ diff --git a/lib/events/complete_test.go b/lib/events/complete_test.go index 16d9412e5b549..7757e4c7e5add 100644 --- a/lib/events/complete_test.go +++ b/lib/events/complete_test.go @@ -18,6 +18,7 @@ package events_test import ( "context" + "fmt" "strings" "testing" "time" @@ -149,6 +150,77 @@ func TestUploadCompleterEmitsSessionEnd(t *testing.T) { } } +func TestCheckUploadsContinuesOnError(t *testing.T) { + clock := clockwork.NewFakeClock() + expires := clock.Now().Add(time.Hour * 1) + + sessionTrackers := []types.SessionTracker{ + &types.SessionTrackerV1{ + Spec: types.SessionTrackerSpecV1{ + SessionID: string(session.NewID()), + }, + ResourceHeader: types.ResourceHeader{ + Metadata: types.Metadata{ + Expires: &expires, + }, + }, + }, + &types.SessionTrackerV1{ + Spec: types.SessionTrackerSpecV1{ + SessionID: string(session.NewID()), + }, + ResourceHeader: types.ResourceHeader{ + Metadata: types.Metadata{ + Expires: &expires, + }, + }, + }, + } + + sessionTrackerService := &mockSessionTrackerService{ + clock: clock, + trackers: sessionTrackers, + } + + var completedUploads []session.ID + uploader := &eventstest.MockUploader{ + MockCompleteUpload: func(ctx context.Context, upload events.StreamUpload, parts []events.StreamPart) error { + // simulate a not found error on the first complete upload + if upload.SessionID == session.ID(sessionTrackers[0].GetSessionID()) { + return trace.NotFound("no such upload %v", sessionTrackers[0].GetSessionID()) + } + + completedUploads = append(completedUploads, upload.SessionID) + return nil + }, + MockListUploads: func(ctx context.Context) ([]events.StreamUpload, error) { + var result []events.StreamUpload + for i, sess := range sessionTrackers { + result = append(result, events.StreamUpload{ + ID: fmt.Sprintf("upload-%v", i), + SessionID: session.ID(sess.GetSessionID()), + Initiated: clock.Now(), + }) + } + return result, nil + }, + } + + uc, err := events.NewUploadCompleter(events.UploadCompleterConfig{ + Uploader: uploader, + AuditLog: &eventstest.MockAuditLog{}, + SessionTracker: sessionTrackerService, + Clock: clock, + ClusterName: "teleport-cluster", + }) + require.NoError(t, err) + + // verify that the 2nd upload completed even though the first one failed + clock.Advance(1 * time.Hour) + uc.CheckUploads(context.Background()) + require.ElementsMatch(t, completedUploads, []session.ID{session.ID(sessionTrackers[1].GetSessionID())}) +} + type mockSessionTrackerService struct { clock clockwork.Clock trackers []types.SessionTracker diff --git a/lib/events/eventstest/uploader.go b/lib/events/eventstest/uploader.go index 6a37c0372477b..c2142ffa3d7ef 100644 --- a/lib/events/eventstest/uploader.go +++ b/lib/events/eventstest/uploader.go @@ -279,3 +279,56 @@ func (m *MemoryUploader) GetUploadMetadata(sid session.ID) events.UploadMetadata func (m *MemoryUploader) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error { return nil } + +// MockUploader is a limited implementation of [events.MultipartUploader] that +// allows injecting errors for testing purposes. [MemoryUploader] is a more +// complete implementation and should be preferred for testing the happy path. +type MockUploader struct { + events.MultipartUploader + + CreateUploadError error + ReserveUploadPartError error + ListPartsError error + + MockListUploads func(ctx context.Context) ([]events.StreamUpload, error) + MockCompleteUpload func(ctx context.Context, upload events.StreamUpload, parts []events.StreamPart) error +} + +func (m *MockUploader) CreateUpload(ctx context.Context, sessionID session.ID) (*events.StreamUpload, error) { + if m.CreateUploadError != nil { + return nil, m.CreateUploadError + } + + return &events.StreamUpload{ + ID: uuid.New().String(), + SessionID: sessionID, + }, nil +} + +func (m *MockUploader) ReserveUploadPart(_ context.Context, _ events.StreamUpload, _ int64) error { + return m.ReserveUploadPartError +} + +func (m *MockUploader) ListParts(_ context.Context, _ events.StreamUpload) ([]events.StreamPart, error) { + if m.ListPartsError != nil { + return nil, m.ListPartsError + } + + return []events.StreamPart{}, nil +} + +func (m *MockUploader) ListUploads(ctx context.Context) ([]events.StreamUpload, error) { + if m.MockListUploads != nil { + return m.MockListUploads(ctx) + } + + return nil, nil +} + +func (m *MockUploader) CompleteUpload(ctx context.Context, upload events.StreamUpload, parts []events.StreamPart) error { + if m.MockCompleteUpload != nil { + return m.MockCompleteUpload(ctx, upload, parts) + } + + return nil +} diff --git a/lib/events/stream_test.go b/lib/events/stream_test.go index c24364ec53e4d..d4ce9731ff703 100644 --- a/lib/events/stream_test.go +++ b/lib/events/stream_test.go @@ -73,7 +73,7 @@ func TestNewSliceErrors(t *testing.T) { ctx := context.Background() expectedErr := errors.New("test upload error") streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ - Uploader: &mockUploader{reserveUploadPartError: expectedErr}, + Uploader: &eventstest.MockUploader{ReserveUploadPartError: expectedErr}, }) require.NoError(t, err) @@ -95,16 +95,16 @@ func TestNewStreamErrors(t *testing.T) { t.Run("CreateAuditStream", func(t *testing.T) { for _, tt := range []struct { desc string - uploader *mockUploader + uploader *eventstest.MockUploader expectedErr error }{ { desc: "CreateUploadError", - uploader: &mockUploader{createUploadError: expectedErr}, + uploader: &eventstest.MockUploader{CreateUploadError: expectedErr}, }, { desc: "ReserveUploadPartError", - uploader: &mockUploader{reserveUploadPartError: expectedErr}, + uploader: &eventstest.MockUploader{ReserveUploadPartError: expectedErr}, }, } { t.Run(tt.desc, func(t *testing.T) { @@ -126,16 +126,16 @@ func TestNewStreamErrors(t *testing.T) { t.Run("ResumeAuditStream", func(t *testing.T) { for _, tt := range []struct { desc string - uploader *mockUploader + uploader *eventstest.MockUploader expectedErr error }{ { desc: "ListPartsError", - uploader: &mockUploader{listPartsError: expectedErr}, + uploader: &eventstest.MockUploader{ListPartsError: expectedErr}, }, { desc: "ReserveUploadPartError", - uploader: &mockUploader{reserveUploadPartError: expectedErr}, + uploader: &eventstest.MockUploader{ReserveUploadPartError: expectedErr}, }, } { t.Run(tt.desc, func(t *testing.T) { @@ -194,36 +194,6 @@ func TestProtoStreamLargeEvent(t *testing.T) { require.NoError(t, stream.Complete(ctx)) } -type mockUploader struct { - events.MultipartUploader - createUploadError error - reserveUploadPartError error - listPartsError error -} - -func (m *mockUploader) CreateUpload(ctx context.Context, sessionID session.ID) (*events.StreamUpload, error) { - if m.createUploadError != nil { - return nil, m.createUploadError - } - - return &events.StreamUpload{ - ID: uuid.New().String(), - SessionID: sessionID, - }, nil -} - -func (m *mockUploader) ReserveUploadPart(_ context.Context, _ events.StreamUpload, _ int64) error { - return m.reserveUploadPartError -} - -func (m *mockUploader) ListParts(_ context.Context, _ events.StreamUpload) ([]events.StreamPart, error) { - if m.listPartsError != nil { - return nil, m.listPartsError - } - - return []events.StreamPart{}, nil -} - func makeQueryEvent(id string, query string) *apievents.DatabaseSessionQuery { return &apievents.DatabaseSessionQuery{ Metadata: apievents.Metadata{