diff --git a/cli/data/get.go b/cli/data/get.go index 6850eb1..bd2f5ad 100644 --- a/cli/data/get.go +++ b/cli/data/get.go @@ -14,9 +14,9 @@ package data import ( - "bytes" "fmt" "io" + "io/ioutil" "log" "os" "path/filepath" @@ -40,10 +40,11 @@ var ( getSubdir string ) +var s3RetriesSleep = 10 * time.Second + const ( s3ParallelGets = 100 s3Retries = 10 - s3RetriesSleep = 10 * time.Second ) var getCmd = &cobra.Command{ @@ -111,17 +112,26 @@ func copyPathToDestination(source S3Path, destination string, keys []string, sub } func readHEAD(session *session.Session, source S3Path) string { + tempFile, err := ioutil.TempFile("", "HEAD") + if err != nil { + exitErrorf("Unable to create temp file: %v", err) + } + + defer os.Remove(tempFile.Name()) + svc := s3.New(session) - out, err := getObject(svc, aws.String(source.bucket), aws.String(source.path)) + err = copyS3ObjectToFile(svc, source, source.path, tempFile) + if err != nil { + exitErrorf("Error copying HEAD: %v", err) + } + contents, err := ioutil.ReadFile(tempFile.Name()) if err != nil { - exitErrorf("Error reading HEAD: %v", err) + exitErrorf("Error reading HEAD file: %v", err) } - buf := new(bytes.Buffer) - buf.ReadFrom(out.Body) - return buf.String() + return string(contents) } func parseDestination(destination string, subdir string) string { @@ -219,58 +229,91 @@ func process(s3Client *s3.S3, src S3Path, basePath string, filePath string, sem return } - out, err := getObject(s3Client, aws.String(src.bucket), &filePath) + destination := basePath + "/" + strings.TrimPrefix(filePath, src.Dirname()+"/") + file, err := createFile(destination) if err != nil { exitErrorf("%v", err) } - defer out.Body.Close() - target := basePath + "/" + strings.TrimPrefix(filePath, src.Dirname()+"/") - err = store(out, target) + defer file.Close() + + err = copyS3ObjectToFile(s3Client, src, filePath, file) if err != nil { exitErrorf("%v", err) } } -func getObject(s3Client *s3.S3, bucket *string, key *string) (*s3.GetObjectOutput, error) { 
- var ( - err error - out *s3.GetObjectOutput - ) +type S3Getter interface { + GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) +} + +func copyS3ObjectToFile(s3Client S3Getter, src S3Path, filePath string, file *os.File) error { + var err error retries := s3Retries for retries > 0 { - out, err = s3Client.GetObject(&s3.GetObjectInput{ - Bucket: bucket, - Key: key, - }) + err = tryGetObject(s3Client, aws.String(src.bucket), &filePath, file) if err == nil { - return out, nil + // we're done + return nil } + resetErr := resetFileForWriting(file) + if resetErr != nil { + fmt.Printf("Unable to download object from S3 (%s) and unable to reset temp file to try again (%s)\n", + err, + resetErr) + return errors.Wrapf(resetErr, "unable to reset temp file %s", file.Name()) + } retries-- if retries > 0 { - fmt.Printf("Error fetching from S3: %s, (%s); will retry in %v... \n", *key, err.Error(), s3RetriesSleep) + fmt.Printf("Error fetching from S3: %s, (%s); will retry in %v... \n", filePath, err.Error(), s3RetriesSleep) time.Sleep(s3RetriesSleep) } } - return nil, err + return err } -func store(obj *s3.GetObjectOutput, destination string) error { - err := os.MkdirAll(filepath.Dir(destination), 0777) +func resetFileForWriting(file *os.File) error { + if err := file.Truncate(0); err != nil { return err } + _, err := file.Seek(0, io.SeekStart) + return err +} + +func tryGetObject(s3Client S3Getter, bucket *string, key *string, file *os.File) error { + out, err := s3Client.GetObject(&s3.GetObjectInput{ + Bucket: bucket, + Key: key, + }) - file, err := os.Create(destination) if err != nil { - return errors.Wrapf(err, "creating destination %s", destination) + return err } - defer file.Close() + if err := storeS3ObjectToFile(out, file); err != nil { out.Body.Close(); return err } + + return out.Body.Close() +} + +func storeS3ObjectToFile(obj *s3.GetObjectOutput, file *os.File) error { bytes, err := io.Copy(file, obj.Body) if err != nil { - return errors.Wrapf(err, "copying file %s", destination) + return errors.Wrapf(err, "copying file %s", file.Name()) } - 
fmt.Printf("%s -> %d bytes\n", destination, bytes) + fmt.Printf("%s -> %d bytes\n", file.Name(), bytes) return nil } + +func createFile(destination string) (*os.File, error) { + err := os.MkdirAll(filepath.Dir(destination), 0777) + if err != nil { + return nil, errors.Wrapf(err, "creating directory %s", filepath.Dir(destination)) + } + + file, err := os.Create(destination) + if err != nil { + return nil, errors.Wrapf(err, "creating destination %s", destination) + } + return file, nil +} diff --git a/cli/data/get_test.go b/cli/data/get_test.go index 7b7de55..b88c9b7 100644 --- a/cli/data/get_test.go +++ b/cli/data/get_test.go @@ -1,8 +1,13 @@ package data import ( + "errors" "github.com/aws/aws-sdk-go/service/s3" + "io" + "io/ioutil" + "strings" "testing" + "time" ) func TestFilterObjects(t *testing.T) { @@ -47,7 +52,7 @@ func TestFilterObjectsWithNoKeys(t *testing.T) { func TestFilterObjectsUsingNonExistentKeys(t *testing.T) { var ( - key = "path/f1.csv" + key = "path/f1.csv" obj = &s3.Object{Key: &key} s3Path = S3Path{bucket: "bucket", path: "path/"} keys = []string{"f2.csv", "f3.csv"} @@ -62,3 +67,165 @@ func TestFilterObjectsUsingNonExistentKeys(t *testing.T) { t.Error("It should return an error") } } + +type s3GetterFromString struct { + s string +} + +func (s3FromString s3GetterFromString) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + out := s3.GetObjectOutput{ + Body: ioutil.NopCloser(strings.NewReader(s3FromString.s)), + } + return &out, nil +} + +func Test_copyS3ObjectToFile_worksFirstTime(t *testing.T) { + var s3Client S3Getter = s3GetterFromString{"foobar"} + + s3Path := S3Path{bucket: "bucket", path: "path/"} + filePath := "foo/bar" + tempFile, _ := ioutil.TempFile("", "testDownload") + + err := copyS3ObjectToFile(s3Client, s3Path, filePath, tempFile) + if err != nil { + t.Errorf("Should have downloaded file successfully but didn't: %v", err) + } + + bytes, err := ioutil.ReadFile(tempFile.Name()) + if err != nil { + 
t.Errorf("Should be able to read from 'downloaded' file but couldn't %v", err) + } + + if string(bytes) != "foobar" { + t.Errorf("File contents were incorrect. Expected '%s' but got '%s'", "foobar", string(bytes)) + } +} + +type s3FailingGetter struct { +} + +func (s3FailingGetter *s3FailingGetter) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + return nil, errors.New("can't connect to S3") +} + +func Test_copyS3ObjectToFile_failsToGetObjectFromS3(t *testing.T) { + var s3Client S3Getter = &s3FailingGetter{} + s3RetriesSleep = 1 * time.Second + + s3Path := S3Path{bucket: "bucket", path: "path/"} + filePath := "foo/bar" + tempFile, _ := ioutil.TempFile("", "testDownload") + + err := copyS3ObjectToFile(s3Client, s3Path, filePath, tempFile) + if err == nil { + t.Errorf("Shouldn't have been able to download file successfully but did") + } +} + +type s3FailingReader struct { +} + +func (s3FailingReader *s3FailingReader) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + out := s3.GetObjectOutput{ + Body: ioutil.NopCloser(&failingReader{}), + } + return &out, nil +} + +type failingReader struct { +} + +func (r *failingReader) Read(p []byte) (int, error) { + return 0, errors.New("failing reader") +} + +func Test_copyS3ObjectToFile_failsToReadFromS3(t *testing.T) { + var s3Client S3Getter = &s3FailingReader{} + s3RetriesSleep = 1 * time.Second + + s3Path := S3Path{bucket: "bucket", path: "path/"} + filePath := "foo/bar" + tempFile, _ := ioutil.TempFile("", "testDownload") + + err := copyS3ObjectToFile(s3Client, s3Path, filePath, tempFile) + if err == nil { + t.Errorf("Shouldn't have been able to download file successfully but did") + } +} + +type s3GetterFailOnClose struct { + s string +} + +func (s3GetterFailOnClose *s3GetterFailOnClose) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + out := s3.GetObjectOutput{ + Body: failOnClose{strings.NewReader(s3GetterFailOnClose.s)}, + } + return &out, nil +} + +type 
failOnClose struct { + io.Reader +} + +func (failOnClose) Close() error { + return errors.New("failed while closing") +} + +func Test_copyS3ObjectToFile_failsWhenClosingStream(t *testing.T) { + var s3Client S3Getter = &s3FailingReader{} + s3RetriesSleep = 1 * time.Second + + s3Path := S3Path{bucket: "bucket", path: "path/"} + filePath := "foo/bar" + tempFile, _ := ioutil.TempFile("", "testDownload") + + err := copyS3ObjectToFile(s3Client, s3Path, filePath, tempFile) + if err == nil { + t.Errorf("Shouldn't have been able to download file successfully but did") + } +} + +type s3GetterFailsFirstFewAttempts struct { + unsuccessfulReads int + s string +} + +func (s3GetterFailsFirstFewAttempts *s3GetterFailsFirstFewAttempts) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + var out s3.GetObjectOutput + if s3GetterFailsFirstFewAttempts.unsuccessfulReads == 0 { + out = s3.GetObjectOutput{ + Body: ioutil.NopCloser(strings.NewReader(s3GetterFailsFirstFewAttempts.s)), + } + } else { + s3GetterFailsFirstFewAttempts.unsuccessfulReads-- + out = s3.GetObjectOutput{ + Body: ioutil.NopCloser(&failingReader{}), + } + } + + return &out, nil +} + +func Test_copyS3ObjectToFile_failsFirstFewReadAttemptsButRetries(t *testing.T) { + var s3Client S3Getter = &s3GetterFailsFirstFewAttempts{5, "foobar"} + s3RetriesSleep = 1 * time.Second + + s3Path := S3Path{bucket: "bucket", path: "path/"} + filePath := "foo/bar" + tempFile, _ := ioutil.TempFile("", "testDownload") + + err := copyS3ObjectToFile(s3Client, s3Path, filePath, tempFile) + if err != nil { + t.Errorf("Should have downloaded file successfully but didn't: %v", err) + } + + bytes, err := ioutil.ReadFile(tempFile.Name()) + if err != nil { + t.Errorf("Should be able to read from 'downloaded' file but couldn't %v", err) + } + + if string(bytes) != "foobar" { + t.Errorf("File contents were incorrect. Expected '%s' but got '%s'", "foobar", string(bytes)) + } +}