Skip to content
This repository has been archived by the owner on Dec 9, 2022. It is now read-only.

Add data commands #1

Merged
merged 2 commits into from
Jul 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ Work in progress.
```
$ paddle help
```

31 changes: 31 additions & 0 deletions cli/data/cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright © 2017 RooFoods LTD
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package data

import (
"github.com/spf13/cobra"
)

// dataCmd represents the data command
var DataCmd = &cobra.Command{
Use: "data",
Short: "Commit and retrieve data",
Long: `Commands to commit data to S3 and retrieve it.
`,
}

func init() {
DataCmd.AddCommand(commitCmd)
DataCmd.AddCommand(getCmd)
}
139 changes: 139 additions & 0 deletions cli/data/commit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright © 2017 RooFoods LTD
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package data

import (
"bytes"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/deliveroo/paddle/common"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"os"
"path/filepath"
"strings"
"time"
)

var commitBranch string

var commitCmd = &cobra.Command{
Use: "commit [source path] [version]",
Short: "Commit data to S3",
Args: cobra.ExactArgs(2),
Long: `Store data into S3 under a versioned path, and update HEAD.

Example:

$ paddle data commit -b experimental source/path trained-model/version1
`,
Run: func(cmd *cobra.Command, args []string) {
if !viper.IsSet("bucket") {
exitErrorf("Bucket not defined. Please define 'bucket' in your config file.")
}
commitPath(args[0], viper.GetString("bucket"), args[1], commitBranch)
},
}

func init() {
commitCmd.Flags().StringVarP(&commitBranch, "branch", "b", "master", "Branch to work on")
}

func commitPath(path string, bucket string, version string, branch string) {
fd, err := os.Stat(path)
if err != nil {
exitErrorf("Path %v not found", path)
}
if !fd.Mode().IsDir() {
exitErrorf("Path %v must be a directory", path)
}

hash, err := common.DirHash(path)
if err != nil {
exitErrorf("Unable to hash input folder")
}

t := time.Now().UTC()

datePath := fmt.Sprintf("%d/%02d/%02d/%02d%02d",
t.Year(), t.Month(), t.Day(),
t.Hour(), t.Minute())

destPath := fmt.Sprintf("%s/%s/%s_%s", version, branch, datePath, hash)

sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

fileList := []string{}
filepath.Walk(path, func(p string, f os.FileInfo, err error) error {
if common.IsDirectory(p) {
return nil
} else {
fileList = append(fileList, p)
return nil
}
})

uploader := s3manager.NewUploader(sess)

for _, file := range fileList {
key := destPath + "/" + strings.TrimPrefix(file, path+"/")
fmt.Println(file + " -> " + key)
uploadFileToS3(uploader, bucket, key, file)
}

// Update HEAD

headFile := fmt.Sprintf("%s/%s/HEAD", version, branch)

uploadDataToS3(sess, destPath, bucket, headFile)
}

func uploadFileToS3(uploader *s3manager.Uploader, bucketName string, key string, filePath string) {
file, err := os.Open(filePath)
if err != nil {
fmt.Println("Failed to open file", file, err)
os.Exit(1)
}
defer file.Close()

_, err = uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(bucketName),
Key: aws.String(key),
Body: file,
})

if err != nil {
exitErrorf("Failed to upload data to %s/%s, %s", bucketName, key, err.Error())
return
}
}

func uploadDataToS3(sess *session.Session, data string, bucket string, key string) {
s3Svc := s3.New(sess)

_, err := s3Svc.PutObject(&s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: bytes.NewReader([]byte(data)),
})

if err != nil {
exitErrorf("Unable to update %s", key)
}
}
11 changes: 11 additions & 0 deletions cli/data/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package data

import (
"fmt"
"os"
)

func exitErrorf(msg string, args ...interface{}) {
fmt.Fprintf(os.Stderr, msg+"\n", args...)
os.Exit(1)
}
133 changes: 133 additions & 0 deletions cli/data/get.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright © 2017 RooFoods LTD
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package data

import (
"bytes"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"io"
"os"
"path/filepath"
"strings"
)

var getBranch string
var getCommitPath string

var getCmd = &cobra.Command{
Use: "get [version] [destination path]",
Short: "Fetch data from S3",
Args: cobra.ExactArgs(2),
Long: `Fetch data from a S3 versioned path.

Example:

$ paddle data get -b experimental trained-model/version1 dest/path
`,
Run: func(cmd *cobra.Command, args []string) {
if !viper.IsSet("bucket") {
exitErrorf("Bucket not defined. Please define 'bucket' in your config file.")
}
fetchPath(viper.GetString("bucket"), args[0], getBranch, getCommitPath, args[1])
},
}

func init() {
getCmd.Flags().StringVarP(&getBranch, "branch", "b", "master", "Branch to work on")
getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)")
}

func fetchPath(bucket string, version string, branch string, path string, destination string) {
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

if path == "HEAD" {
svc := s3.New(sess)
headPath := fmt.Sprintf("%s/%s/HEAD", version, branch)
fmt.Println(headPath)
out, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(headPath),
})
if err != nil {
exitErrorf("%v", err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(out.Body)
path = buf.String()
} else {
path = fmt.Sprintf("%s/%s/%s", version, branch, path)
}
fmt.Println("Fetching " + path)
getBucketObjects(sess, bucket, path, destination)
}

func getBucketObjects(sess *session.Session, bucket string, prefix string, dest string) {
query := &s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
Prefix: aws.String(prefix),
}
svc := s3.New(sess)

truncatedListing := true

for truncatedListing {
resp, err := svc.ListObjectsV2(query)

if err != nil {
fmt.Println(err.Error())
return
}
getObjectsAll(bucket, resp, svc, prefix, dest)
query.ContinuationToken = resp.NextContinuationToken
truncatedListing = *resp.IsTruncated
}
}

func getObjectsAll(bucket string, bucketObjectsList *s3.ListObjectsV2Output, s3Client *s3.S3, prefix string, dest string) {
for _, key := range bucketObjectsList.Contents {
destFilename := *key.Key
if strings.HasSuffix(*key.Key, "/") {
fmt.Println("Got a directory")
continue
}
out, err := s3Client.GetObject(&s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: key.Key,
})
if err != nil {
exitErrorf("%v", err)
}
destFilePath := dest + "/" + strings.TrimPrefix(destFilename, prefix+"/")
err = os.MkdirAll(filepath.Dir(destFilePath), 0777)
fmt.Print(destFilePath)
destFile, err := os.Create(destFilePath)
if err != nil {
exitErrorf("%v", err)
}
bytes, err := io.Copy(destFile, out.Body)
if err != nil {
exitErrorf("%v", err)
}
fmt.Printf(" -> %d bytes\n", bytes)
out.Body.Close()
destFile.Close()
}
}
Loading