Skip to content

Commit

Permalink
chore(storage/transfermanager): DownloadDirectory fails if file alrea…
Browse files Browse the repository at this point in the history
…dy exists (#10507)
  • Loading branch information
BrennaEpp authored Jul 8, 2024
1 parent a428615 commit 09a467d
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 3 deletions.
68 changes: 65 additions & 3 deletions storage/transfermanager/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ func (d *Downloader) DownloadObject(ctx context.Context, input *DownloadObjectIn
// or use the callback to process the result. DownloadDirectory is thread-safe
// and can be called simultaneously from different goroutines.
// DownloadDirectory will resolve any filters on the input and create the needed
// directory structure locally as the operations progress.
// Note: DownloadDirectory overwrites existing files in the directory.
// directory structure locally. Do not modify this structure until the download
// has completed.
// DownloadDirectory will fail if any of the files it attempts to download
// already exist in the local directory.
func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirectoryInput) error {
if d.closed() {
return errors.New("transfermanager: Downloader used after WaitAndClose was called")
Expand All @@ -92,7 +94,51 @@ func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirec
return fmt.Errorf("transfermanager: DownloadDirectory query.SetAttrSelection: %w", err)
}

// TODO: Clean up any created directory structure on failure.
// Record every directory that already exists under the local directory so the
// pre-download state can be restored if the operation fails.
localDirSnapshot := make(map[string]bool) // set of directory paths found under the local directory
if err := filepath.WalkDir(input.LocalDirectory, func(p string, entry os.DirEntry, walkErr error) error {
	switch {
	case walkErr == nil:
		if entry.IsDir() {
			localDirSnapshot[p] = true
		}
		return nil
	case errors.Is(walkErr, os.ErrNotExist):
		// A path that does not exist has nothing to record.
		return nil
	default:
		return walkErr
	}
}); err != nil {
	return fmt.Errorf("transfermanager: local directory walkthrough failed: %w", err)
}

// cleanFiles undoes the local side effects of a failed download. It closes and
// deletes the destination file of every entry in inputs, then removes every
// directory under input.LocalDirectory that is not present in
// localDirSnapshot. File cleanup is best effort; the first error from walking
// or removing directories is returned.
cleanFiles := func(inputs []DownloadObjectInput) error {
	// Remove all created files.
	for _, in := range inputs {
		f, ok := in.Destination.(*os.File)
		if !ok || f == nil {
			// Not a file we can locate on disk; skip it rather than panic
			// in the middle of error cleanup.
			continue
		}
		// Best effort: the file may already be closed or removed, and a
		// failure here must not prevent the remaining cleanup.
		f.Close()
		os.Remove(f.Name())
	}

	// Remove all created dirs.
	var removePaths []string
	if err := filepath.WalkDir(input.LocalDirectory, func(path string, d os.DirEntry, err error) error {
		// Check err before touching d: WalkDir passes a nil d when the
		// initial lstat of the root fails.
		if err != nil {
			if errors.Is(err, os.ErrNotExist) {
				// Nothing exists at this path, so there is nothing to remove.
				return nil
			}
			return err
		}
		if d.IsDir() && !localDirSnapshot[path] {
			removePaths = append(removePaths, path)
			// We don't need to go into subdirectories, since this directory needs to be removed.
			return filepath.SkipDir
		}
		return nil
	}); err != nil {
		return fmt.Errorf("transfermanager: local directory walkthrough failed: %w", err)
	}

	for _, path := range removePaths {
		if err := os.RemoveAll(path); err != nil {
			return fmt.Errorf("transfermanager: failed to remove directory: %w", err)
		}
	}
	return nil
}

objectsToQueue := []string{}
it := d.client.Bucket(input.Bucket).Objects(ctx, query)
Expand All @@ -105,6 +151,15 @@ func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirec
return fmt.Errorf("transfermanager: DownloadDirectory failed to list objects: %w", err)
}

// Check if the file exists.
// TODO: add skip option.
filePath := filepath.Join(input.LocalDirectory, attrs.Name)
if _, err := os.Stat(filePath); err == nil {
return fmt.Errorf("transfermanager: failed to create file(%q): %w", filePath, os.ErrExist)
} else if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("transfermanager: failed to create file(%q): %w", filePath, err)
}

objectsToQueue = append(objectsToQueue, attrs.Name)
}

Expand All @@ -118,12 +173,14 @@ func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirec
// Make sure all directories in the object path exist.
err := os.MkdirAll(objDirectory, fs.ModeDir|fs.ModePerm)
if err != nil {
cleanFiles(inputs)
return fmt.Errorf("transfermanager: DownloadDirectory failed to make directory(%q): %w", objDirectory, err)
}

// Create file to download to.
f, fErr := os.Create(filePath)
if fErr != nil {
cleanFiles(inputs)
return fmt.Errorf("transfermanager: DownloadDirectory failed to create file(%q): %w", filePath, fErr)
}

Expand Down Expand Up @@ -239,6 +296,11 @@ func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutpu
result.Err = fmt.Errorf("closing file(%q): %w", f.Name(), err)
}

// Clean up the file if it failed.
if result.Err != nil {
os.Remove(f.Name())
}

if d.config.asynchronous {
input.directoryObjectOutputs <- copiedResult
}
Expand Down
81 changes: 81 additions & 0 deletions storage/transfermanager/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"hash/crc32"
"io"
"io/fs"
"log"
"math/rand"
"os"
Expand Down Expand Up @@ -296,6 +297,86 @@ func TestIntegration_DownloadDirectoryAsync(t *testing.T) {
})
}

// TestIntegration_DownloadDirectoryError tests that DownloadDirectory returns
// an error wrapping os.ErrExist when a file it would download already exists
// locally, and that the local directory afterwards contains only the entries
// that were present before the failed call.
func TestIntegration_DownloadDirectoryError(t *testing.T) {
	multiTransportTest(context.Background(), t, func(t *testing.T, ctx context.Context, c *storage.Client, tb downloadTestBucket) {
		localDir := t.TempDir()
		callbacks := make(chan bool)

		d, err := NewDownloader(c, WithWorkers(2))
		if err != nil {
			t.Fatalf("NewDownloader: %v", err)
		}
		defer d.WaitAndClose()

		// First download a nested file.
		obj := "dir/objC"
		if err := d.DownloadDirectory(ctx, &DownloadDirectoryInput{
			Bucket:         tb.bucket,
			LocalDirectory: localDir,
			Prefix:         obj,
			OnObjectDownload: func(got *DownloadOutput) {
				// callbacks is unbuffered, so this send blocks until the
				// test body receives from it below.
				callbacks <- true

				if got.Err != nil {
					t.Errorf("result.Err: %v", got.Err)
				}

				if got, want := got.Attrs.Size, tb.objectSizes[got.Object]; want != got {
					t.Errorf("expected object size %d, got %d", want, got)
				}
			},
		}); err != nil {
			t.Fatalf("d.DownloadDirectory: %v", err)
		}

		// Then add another file and another directory to the temp dir.
		// Setup failures are fatal: the assertions below are meaningless
		// without these entries in place.
		localNestedDir := "localonly"
		localFile := "localonly-file"

		f, err := os.Create(filepath.Join(localDir, localFile))
		if err != nil {
			t.Fatalf("os.Create: %v", err)
		}
		if err := f.Close(); err != nil {
			t.Fatalf("f.Close: %v", err)
		}

		if err := os.Mkdir(filepath.Join(localDir, localNestedDir), fs.ModeDir|fs.ModePerm); err != nil {
			t.Fatalf("os.Mkdir: %v", err)
		}

		// Now attempt to download the directory; it should fail.
		// Receive the first download's callback signal before starting.
		<-callbacks
		err = d.DownloadDirectory(ctx, &DownloadDirectoryInput{
			Bucket:         tb.bucket,
			LocalDirectory: localDir,
		})
		if !errors.Is(err, os.ErrExist) {
			t.Errorf("d.DownloadDirectory should have failed with error %q; got %v", os.ErrExist, err)
		}

		// Check the local directory, it should have the first file, the second and another directory only.
		expected := []string{obj, localNestedDir, localFile}

		entries, err := os.ReadDir(localDir)
		if err != nil {
			t.Fatalf("os.ReadDir: %v", err)
		}
		// obj is nested, so its top-level "dir" entry accounts for one of
		// the three expected top-level entries.
		if got, want := len(entries), len(expected); got != want {
			t.Errorf("localDir does not have the expected amount of entries, got %d, want %d", got, want)
		}

		// Keep going after a missing entry so every absent path is reported.
		for _, exp := range expected {
			_, err := os.Stat(filepath.Join(localDir, exp))
			if err != nil {
				t.Errorf("os.Stat: %v", err)
			}
		}
	})
}

func TestIntegration_DownloaderSynchronous(t *testing.T) {
multiTransportTest(context.Background(), t, func(t *testing.T, ctx context.Context, c *storage.Client, tb downloadTestBucket) {
objects := tb.objects
Expand Down

0 comments on commit 09a467d

Please sign in to comment.