diff --git a/lib/events/filesessions/filestream.go b/lib/events/filesessions/filestream.go index b9755602e9dca..75d187618f2ee 100644 --- a/lib/events/filesessions/filestream.go +++ b/lib/events/filesessions/filestream.go @@ -170,6 +170,11 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload return trace.Wrap(err) } + // If there are no parts to complete, move to cleanup + if len(parts) == 0 { + return h.cleanupUpload(ctx, upload) + } + uploadPath := h.recordingPath(upload.SessionID) // Prevent other processes from accessing this file until the write is completed @@ -251,6 +256,22 @@ Loop: return nil } +func (h *Handler) cleanupUpload(ctx context.Context, upload events.StreamUpload) error { + uploadKey := h.recordingPath(upload.SessionID) + log := h.logger.With( + "upload", upload.ID, + "session", upload.SessionID, + "key", uploadKey, + ) + log.DebugContext(ctx, "Aborting upload") + if err := os.RemoveAll(h.uploadRootPath(upload)); err != nil { + h.logger.ErrorContext(ctx, "Failed to remove upload", "upload_id", upload.ID) + } + + log.InfoContext(ctx, "Aborted upload") + return nil +} + // ListParts lists upload parts func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([]events.StreamPart, error) { var parts []events.StreamPart diff --git a/lib/events/filesessions/filestream_test.go b/lib/events/filesessions/filestream_test.go index 0994e80a9e72a..3e8e4b797b38e 100644 --- a/lib/events/filesessions/filestream_test.go +++ b/lib/events/filesessions/filestream_test.go @@ -128,14 +128,6 @@ func TestCompleteUpload(t *testing.T) { createPart(t, handler, upload, int64(5), []byte("withreservation")) }, }, - { - desc: "OnlyReservation", - expectedContent: []byte{}, - partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) { - createPart(t, handler, upload, int64(1), []byte{}) - createPart(t, handler, upload, int64(2), []byte{}) - }, - }, } { t.Run(test.desc, func(t *testing.T) { handler, err := NewHandler(Config{ @@ -169,3 +161,48 @@ func TestCompleteUpload(t *testing.T) { }) } } + +func TestCleanupEmptyUpload(t *testing.T) { + ctx := t.Context() + + handler, err := NewHandler(Config{ + Directory: t.TempDir(), + OpenFile: os.OpenFile, + }) + require.NoError(t, err) + + sessionID := session.NewID() + + // Create a completed upload. + upload, err := handler.CreateUpload(ctx, sessionID) + require.NoError(t, err) + + err = handler.ReserveUploadPart(ctx, *upload, 1) + require.NoError(t, err) + + content := []byte("hello world") + part, err := handler.UploadPart(ctx, *upload, 1, bytes.NewReader(content)) + require.NoError(t, err) + + err = handler.CompleteUpload(ctx, *upload, []events.StreamPart{*part}) + require.NoError(t, err) + + // Create an empty upload with the same session ID and try to complete it. + emptyUpload, err := handler.CreateUpload(ctx, sessionID) + require.NoError(t, err) + + err = handler.CompleteUpload(ctx, *emptyUpload, []events.StreamPart{}) + require.NoError(t, err) + + // The empty upload should be cleaned up without impacting the original completed upload. + uploadPath := handler.recordingPath(upload.SessionID) + f, err := os.Open(uploadPath) + require.NoError(t, err) + + gotContent, err := io.ReadAll(f) + require.NoError(t, err) + require.Equal(t, content, gotContent) + + require.NoDirExists(t, handler.uploadRootPath(*upload)) + require.NoDirExists(t, handler.uploadRootPath(*emptyUpload)) +}