From 8bf1d6ba65c092448b92730f5a44cb70ee29c061 Mon Sep 17 00:00:00 2001 From: Pedro Cunha Date: Mon, 25 Sep 2017 16:01:08 +0100 Subject: [PATCH] Refactor of get --- cli/data/get.go | 99 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 33 deletions(-) diff --git a/cli/data/get.go b/cli/data/get.go index 24e0b8b..eb6b857 100644 --- a/cli/data/get.go +++ b/cli/data/get.go @@ -27,6 +27,21 @@ import ( "strings" ) +type S3Target struct { + bucket string + prefix string + path string +} + +func (s *S3Target) copy() *S3Target { + clone := *s + return &clone +} + +func (t *S3Target) fullPath() string { + return fmt.Sprintf("%s/%s/%s", t.bucket, t.prefix, t.path); +} + var getBranch string var getCommitPath string @@ -44,7 +59,14 @@ $ paddle data get -b experimental trained-model/version1 dest/path if !viper.IsSet("bucket") { exitErrorf("Bucket not defined. Please define 'bucket' in your config file.") } - fetchPath(viper.GetString("bucket"), args[0], getBranch, getCommitPath, args[1]) + + source := S3Target{ + bucket: viper.GetString("bucket"), + prefix: fmt.Sprintf("%s/%s", args[0], getBranch), + path: getCommitPath, + } + + copyPathToDestination(&source, args[1]) }, } @@ -53,69 +75,80 @@ func init() { getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)") } -func fetchPath(bucket string, version string, branch string, path string, destination string) { - sess := session.Must(session.NewSessionWithOptions(session.Options{ +func copyPathToDestination(source *S3Target, destination string) { + session := session.Must(session.NewSessionWithOptions(session.Options{ SharedConfigState: session.SharedConfigEnable, })) - if path == "HEAD" { - svc := s3.New(sess) - headPath := fmt.Sprintf("%s/%s/HEAD", version, branch) - fmt.Println(headPath) - out, err := svc.GetObject(&s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(headPath), - }) - if err != nil { - exitErrorf("%v", err) - } - buf := new(bytes.Buffer) - buf.ReadFrom(out.Body) - path = buf.String() - } else { - path = fmt.Sprintf("%s/%s/%s", version, branch, path) + /* + * HEAD contains the path to latest folder + */ + if source.path == "HEAD" { + source = source.copy() + source.path = readHEAD(session, source) + } + + fmt.Println("Copying " + source.fullPath() + " to " + destination) + copy(session, source, destination) +} + +func readHEAD(session *session.Session, source *S3Target) string { + svc := s3.New(session) + key := fmt.Sprintf("%s/HEAD", source.prefix) + + out, err := svc.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(source.bucket), + Key: aws.String(key), + }) + + if err != nil { + exitErrorf("%v", err) } - fmt.Println("Fetching " + path) - getBucketObjects(sess, bucket, path, destination) + + buf := new(bytes.Buffer) + buf.ReadFrom(out.Body) + return buf.String() } -func getBucketObjects(sess *session.Session, bucket string, prefix string, dest string) { +func copy(session *session.Session, source *S3Target, destination string) { query := &s3.ListObjectsV2Input{ - Bucket: aws.String(bucket), - Prefix: aws.String(prefix), + Bucket: aws.String(source.bucket), + Prefix: aws.String(source.prefix + "/" + source.path), } - svc := s3.New(sess) + svc := s3.New(session) truncatedListing := true for truncatedListing { - resp, err := svc.ListObjectsV2(query) + response, err := svc.ListObjectsV2(query) if err != nil { fmt.Println(err.Error()) return } - getObjectsAll(bucket, resp, svc, prefix, dest) - query.ContinuationToken = resp.NextContinuationToken - truncatedListing = *resp.IsTruncated + copyToLocalFiles(svc, response.Contents, source, destination) + + // Check if more results + query.ContinuationToken = response.NextContinuationToken + truncatedListing = *response.IsTruncated } } -func getObjectsAll(bucket string, bucketObjectsList *s3.ListObjectsV2Output, s3Client *s3.S3, prefix string, dest string) { - for _, key := range bucketObjectsList.Contents { +func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source *S3Target, destination string) { + for _, key := range objects { destFilename := *key.Key if strings.HasSuffix(*key.Key, "/") { fmt.Println("Got a directory") continue } out, err := s3Client.GetObject(&s3.GetObjectInput{ - Bucket: aws.String(bucket), + Bucket: aws.String(source.bucket), Key: key.Key, }) if err != nil { exitErrorf("%v", err) } - destFilePath := dest + "/" + strings.TrimPrefix(destFilename, prefix+"/") + destFilePath := destination + "/" + strings.TrimPrefix(destFilename, source.prefix + "/") err = os.MkdirAll(filepath.Dir(destFilePath), 0777) fmt.Print(destFilePath) destFile, err := os.Create(destFilePath)