From 61c25ae607b46621131ba600951a28942eeacc93 Mon Sep 17 00:00:00 2001 From: Pedro Cunha Date: Mon, 25 Sep 2017 18:15:46 +0100 Subject: [PATCH] Better S3 interface + tests --- cli/data/get.go | 30 ++++++++++----------- cli/data/s3path.go | 23 ++++++++++++++++ cli/data/s3path_test.go | 59 +++++++++++++++++++++++++++++++++++++++++ cli/data/s3source.go | 20 -------------- 4 files changed, 96 insertions(+), 36 deletions(-) create mode 100644 cli/data/s3path.go create mode 100644 cli/data/s3path_test.go delete mode 100644 cli/data/s3source.go diff --git a/cli/data/get.go b/cli/data/get.go index a3bd501..5c55772 100644 --- a/cli/data/get.go +++ b/cli/data/get.go @@ -45,13 +45,12 @@ $ paddle data get -b experimental trained-model/version1 dest/path exitErrorf("Bucket not defined. Please define 'bucket' in your config file.") } - source := S3Source{ + source := S3Path{ bucket: viper.GetString("bucket"), - prefix: fmt.Sprintf("%s/%s", args[0], getBranch), - path: getCommitPath, + path: fmt.Sprintf("%s/%s/%s", args[0], getBranch, getCommitPath), } - copyPathToDestination(&source, args[1]) + copyPathToDestination(source, args[1]) }, } @@ -60,7 +59,7 @@ func init() { getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)") } -func copyPathToDestination(source *S3Source, destination string) { +func copyPathToDestination(source S3Path, destination string) { session := session.Must(session.NewSessionWithOptions(session.Options{ SharedConfigState: session.SharedConfigEnable, })) @@ -68,22 +67,21 @@ func copyPathToDestination(source *S3Source, destination string) { /* * HEAD contains the path to latest folder */ - if source.path == "HEAD" { - source = source.copy() - source.path = readHEAD(session, source) + if source.Basename() == "HEAD" { + latestFolder := readHEAD(session, source) + source.path = strings.Replace(source.path, "HEAD", latestFolder, 1) } - fmt.Println("Copying " + source.fullPath() + " to " + destination) + fmt.Println("Copying " + source.path + " to " + destination) copy(session, source, destination) } -func readHEAD(session *session.Session, source *S3Source) string { +func readHEAD(session *session.Session, source S3Path) string { svc := s3.New(session) - key := fmt.Sprintf("%s/HEAD", source.prefix) out, err := svc.GetObject(&s3.GetObjectInput{ Bucket: aws.String(source.bucket), - Key: aws.String(key), + Key: aws.String(source.path), }) if err != nil { @@ -95,10 +93,10 @@ func readHEAD(session *session.Session, source *S3Source) string { return buf.String() } -func copy(session *session.Session, source *S3Source, destination string) { +func copy(session *session.Session, source S3Path, destination string) { query := &s3.ListObjectsV2Input{ Bucket: aws.String(source.bucket), - Prefix: aws.String(source.prefix + "/" + source.path), + Prefix: aws.String(source.path), } svc := s3.New(session) @@ -119,7 +117,7 @@ func copy(session *session.Session, source *S3Source, destination string) { } } -func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source *S3Source, destination string) { +func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source S3Path, destination string) { for _, key := range objects { destFilename := *key.Key if strings.HasSuffix(*key.Key, "/") { @@ -133,7 +131,7 @@ func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source *S3Source, d if err != nil { exitErrorf("%v", err) } - destFilePath := destination + "/" + strings.TrimPrefix(destFilename, source.prefix + "/") + destFilePath := destination + "/" + strings.TrimPrefix(destFilename, source.Dirname() + "/") err = os.MkdirAll(filepath.Dir(destFilePath), 0777) fmt.Print(destFilePath) destFile, err := os.Create(destFilePath) diff --git a/cli/data/s3path.go b/cli/data/s3path.go new file mode 100644 index 0000000..9cdfcd4 --- /dev/null +++ b/cli/data/s3path.go @@ -0,0 +1,23 @@ +package data + +import ( + "strings" +) + +type S3Path struct { + bucket string + path string +} + +func (p *S3Path) Basename() string { + components := strings.Split(p.path, "/") + return components[len(components)-1] +} + +func (p *S3Path) Dirname() string { + components := strings.Split(p.path, "/") + if len(components) == 0 { + return "" + } + return strings.Join(components[:len(components)-1], "/") +} diff --git a/cli/data/s3path_test.go b/cli/data/s3path_test.go new file mode 100644 index 0000000..29d8be1 --- /dev/null +++ b/cli/data/s3path_test.go @@ -0,0 +1,59 @@ +package data + +import "testing" + +func TestBasename(t *testing.T) { + path := S3Path{ + bucket: "foo", + path: "aaa/bbb/ccc", + } + + dirname := path.Basename() + expectation := "ccc" + + if dirname != expectation { + t.Errorf("Basename was incorrect, got: %s, want: %s.", dirname, expectation) + } +} + +func TestBasenameWithEmptyPath(t *testing.T) { + path := S3Path{ + bucket: "foo", + path: "", + } + + dirname := path.Basename() + expectation := "" + + if dirname != expectation { + t.Errorf("Basename was incorrect, got: %s, want: %s.", dirname, expectation) + } +} + +func TestDirname(t *testing.T) { + path := S3Path{ + bucket: "foo", + path: "aaa/bbb/ccc", + } + + dirname := path.Dirname() + expectation := "aaa/bbb" + + if dirname != expectation { + t.Errorf("Dirname was incorrect, got: %s, want: %s.", dirname, expectation) + } +} + +func TestDirnameOnBasicPath(t *testing.T) { + path := S3Path{ + bucket: "foo", + path: "aaa", + } + + dirname := path.Dirname() + expectation := "" + + if dirname != expectation { + t.Errorf("Dirname was incorrect, got: %s, want: %s.", dirname, expectation) + } +} diff --git a/cli/data/s3source.go b/cli/data/s3source.go deleted file mode 100644 index b7abe54..0000000 --- a/cli/data/s3source.go +++ /dev/null @@ -1,20 +0,0 @@ -package data - -import ( - "fmt" -) - -type S3Source struct { - bucket string - prefix string - path string -} - -func (s *S3Source) copy() *S3Source { - clone := *s - return &clone -} - -func (t *S3Source) fullPath() string { - return fmt.Sprintf("%s/%s/%s", t.bucket, t.prefix, t.path); -}