Skip to content
This repository has been archived by the owner on Dec 9, 2022. It is now read-only.

Commit

Permalink
Better S3 interface + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Pedro Cunha committed Sep 25, 2017
1 parent 48fd3a6 commit 61c25ae
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 36 deletions.
30 changes: 14 additions & 16 deletions cli/data/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,12 @@ $ paddle data get -b experimental trained-model/version1 dest/path
exitErrorf("Bucket not defined. Please define 'bucket' in your config file.")
}

source := S3Source{
source := S3Path{
bucket: viper.GetString("bucket"),
prefix: fmt.Sprintf("%s/%s", args[0], getBranch),
path: getCommitPath,
path: fmt.Sprintf("%s/%s/%s", args[0], getBranch, getCommitPath),
}

copyPathToDestination(&source, args[1])
copyPathToDestination(source, args[1])
},
}

Expand All @@ -60,30 +59,29 @@ func init() {
getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)")
}

func copyPathToDestination(source *S3Source, destination string) {
func copyPathToDestination(source S3Path, destination string) {
session := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

/*
* HEAD contains the path to latest folder
*/
if source.path == "HEAD" {
source = source.copy()
source.path = readHEAD(session, source)
if source.Basename() == "HEAD" {
latestFolder := readHEAD(session, source)
source.path = strings.Replace(source.path, "HEAD", latestFolder, 1)
}

fmt.Println("Copying " + source.fullPath() + " to " + destination)
fmt.Println("Copying " + source.path + " to " + destination)
copy(session, source, destination)
}

func readHEAD(session *session.Session, source *S3Source) string {
func readHEAD(session *session.Session, source S3Path) string {
svc := s3.New(session)
key := fmt.Sprintf("%s/HEAD", source.prefix)

out, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String(source.bucket),
Key: aws.String(key),
Key: aws.String(source.path),
})

if err != nil {
Expand All @@ -95,10 +93,10 @@ func readHEAD(session *session.Session, source *S3Source) string {
return buf.String()
}

func copy(session *session.Session, source *S3Source, destination string) {
func copy(session *session.Session, source S3Path, destination string) {
query := &s3.ListObjectsV2Input{
Bucket: aws.String(source.bucket),
Prefix: aws.String(source.prefix + "/" + source.path),
Prefix: aws.String(source.path),
}
svc := s3.New(session)

Expand All @@ -119,7 +117,7 @@ func copy(session *session.Session, source *S3Source, destination string) {
}
}

func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source *S3Source, destination string) {
func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source S3Path, destination string) {
for _, key := range objects {
destFilename := *key.Key
if strings.HasSuffix(*key.Key, "/") {
Expand All @@ -133,7 +131,7 @@ func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source *S3Source, d
if err != nil {
exitErrorf("%v", err)
}
destFilePath := destination + "/" + strings.TrimPrefix(destFilename, source.prefix + "/")
destFilePath := destination + "/" + strings.TrimPrefix(destFilename, source.Dirname() + "/")
err = os.MkdirAll(filepath.Dir(destFilePath), 0777)
fmt.Print(destFilePath)
destFile, err := os.Create(destFilePath)
Expand Down
23 changes: 23 additions & 0 deletions cli/data/s3path.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package data

import (
"strings"
)

type S3Path struct {
bucket string
path string
}

func (p *S3Path) Basename() string {
components := strings.Split(p.path, "/")
return components[len(components)-1]
}

func (p *S3Path) Dirname() string {
components := strings.Split(p.path, "/")
if len(components) == 0 {
return ""
}
return strings.Join(components[:len(components)-1], "/")
}
59 changes: 59 additions & 0 deletions cli/data/s3path_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package data

import "testing"

func TestBasename(t *testing.T) {
path := S3Path{
bucket: "foo",
path: "aaa/bbb/ccc",
}

dirname := path.Basename()
expectation := "ccc"

if dirname != expectation {
t.Errorf("Basename was incorrect, got: %s, want: %s.", dirname, expectation)
}
}

func TestBasenameWithEmptyPath(t *testing.T) {
path := S3Path{
bucket: "foo",
path: "",
}

dirname := path.Basename()
expectation := ""

if dirname != expectation {
t.Errorf("Basename was incorrect, got: %s, want: %s.", dirname, expectation)
}
}

func TestDirname(t *testing.T) {
path := S3Path{
bucket: "foo",
path: "aaa/bbb/ccc",
}

dirname := path.Dirname()
expectation := "aaa/bbb"

if dirname != expectation {
t.Errorf("Dirname was incorrect, got: %s, want: %s.", dirname, expectation)
}
}

func TestDirnameOnBasicPath(t *testing.T) {
path := S3Path{
bucket: "foo",
path: "aaa",
}

dirname := path.Dirname()
expectation := ""

if dirname != expectation {
t.Errorf("Dirname was incorrect, got: %s, want: %s.", dirname, expectation)
}
}
20 changes: 0 additions & 20 deletions cli/data/s3source.go

This file was deleted.

0 comments on commit 61c25ae

Please sign in to comment.