From c89a462d0c3cb33946832b72a91f123abb24d8e6 Mon Sep 17 00:00:00 2001 From: Noah Hsu Date: Sat, 3 Sep 2022 21:38:43 +0800 Subject: [PATCH] feat: add s3 driver --- drivers/all.go | 1 + drivers/s3/driver.go | 160 +++++++++++++++++++++++++++++++++ drivers/s3/meta.go | 33 +++++++ drivers/s3/types.go | 1 + drivers/s3/util.go | 178 +++++++++++++++++++++++++++++++++++++ drivers/template/driver.go | 2 +- 6 files changed, 374 insertions(+), 1 deletion(-) create mode 100644 drivers/s3/driver.go create mode 100644 drivers/s3/meta.go create mode 100644 drivers/s3/types.go create mode 100644 drivers/s3/util.go diff --git a/drivers/all.go b/drivers/all.go index cbd1f9785da..4a50b64155d 100644 --- a/drivers/all.go +++ b/drivers/all.go @@ -9,6 +9,7 @@ import ( _ "github.com/alist-org/alist/v3/drivers/onedrive" _ "github.com/alist-org/alist/v3/drivers/pikpak" _ "github.com/alist-org/alist/v3/drivers/quark" + _ "github.com/alist-org/alist/v3/drivers/s3" _ "github.com/alist-org/alist/v3/drivers/teambition" _ "github.com/alist-org/alist/v3/drivers/virtual" ) diff --git a/drivers/s3/driver.go b/drivers/s3/driver.go new file mode 100644 index 00000000000..8307628e60a --- /dev/null +++ b/drivers/s3/driver.go @@ -0,0 +1,160 @@ +package s3 + +import ( + "bytes" + "context" + "fmt" + "io" + "net/url" + stdpath "path" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + log "github.com/sirupsen/logrus" +) + +type S3 struct { + model.Storage + Addition + Session *session.Session + client *s3.S3 + linkClient *s3.S3 +} + +func (d *S3) Config() driver.Config { + return config +} + +func (d *S3) GetAddition() driver.Additional { + return d.Addition +} + +func (d *S3) Init(ctx context.Context, storage model.Storage) error { + d.Storage = storage + err := utils.Json.UnmarshalFromString(d.Storage.Addition, &d.Addition) + if err != nil { + return err + } + if d.Region == "" { + d.Region = "alist" + } + err = d.initSession() + if err != nil { + return err + } + d.client = d.getClient(false) + d.linkClient = d.getClient(true) + return nil +} + +func (d *S3) Drop(ctx context.Context) error { + return nil +} + +func (d *S3) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if d.ListObjectVersion == "v2" { + return d.listV2(dir.GetPath()) + } + return d.listV1(dir.GetPath()) +} + +//func (d *S3) Get(ctx context.Context, path string) (model.Obj, error) { +// // this is optional +// return nil, errs.NotImplement +//} + +func (d *S3) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + path := getKey(file.GetPath(), false) + disposition := fmt.Sprintf(`attachment;filename="%s"`, url.QueryEscape(stdpath.Base(path))) + input := &s3.GetObjectInput{ + Bucket: &d.Bucket, + Key: &path, + //ResponseContentDisposition: &disposition, + } + if d.CustomHost == "" { + input.ResponseContentDisposition = &disposition + } + req, _ := d.linkClient.GetObjectRequest(input) + var link string + var err error + if d.CustomHost != "" { + err = req.Build() + link = req.HTTPRequest.URL.String() + } else { + link, err = req.Presign(time.Hour * time.Duration(d.SignURLExpire)) + } + if err != nil { + return nil, err + } + return &model.Link{ + URL: link, + }, nil +} + +func (d *S3) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + return d.Put(ctx, &model.Object{ + Path: stdpath.Join(parentDir.GetPath(), dirName), + }, &model.FileStream{ + Obj: &model.Object{ + Name: getPlaceholderName(d.Placeholder), + Modified: time.Now(), + }, + ReadCloser: io.NopCloser(bytes.NewReader([]byte{})), + Mimetype: "application/octet-stream", + }, func(int) {}) +} + +func (d *S3) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + err := d.Copy(ctx, srcObj, dstDir) + if err != nil { + return err + } + return d.Remove(ctx, srcObj) +} + +func (d *S3) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + err := d.copy(srcObj.GetPath(), stdpath.Join(stdpath.Dir(srcObj.GetPath()), newName), srcObj.IsDir()) + if err != nil { + return err + } + return d.Remove(ctx, srcObj) +} + +func (d *S3) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return d.copy(srcObj.GetPath(), stdpath.Join(dstDir.GetPath(), stdpath.Base(srcObj.GetPath())), srcObj.IsDir()) +} + +func (d *S3) Remove(ctx context.Context, obj model.Obj) error { + key := getKey(obj.GetPath(), obj.IsDir()) + input := &s3.DeleteObjectInput{ + Bucket: &d.Bucket, + Key: &key, + } + _, err := d.client.DeleteObject(input) + return err +} + +func (d *S3) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + uploader := s3manager.NewUploader(d.Session) + key := getKey(stdpath.Join(dstDir.GetPath(), stream.GetName()), false) + log.Debugln("key:", key) + input := &s3manager.UploadInput{ + Bucket: &d.Bucket, + Key: &key, + Body: stream, + } + _, err := uploader.Upload(input) + return err +} + +func (d *S3) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + return nil, errs.NotSupport +} + +var _ driver.Driver = (*S3)(nil) diff --git a/drivers/s3/meta.go b/drivers/s3/meta.go new file mode 100644 index 00000000000..782d81ec2fb --- /dev/null +++ b/drivers/s3/meta.go @@ -0,0 +1,33 @@ +package s3 + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootFolderPath + Bucket string `json:"bucket" required:"true"` + Endpoint string `json:"endpoint" required:"true"` + Region string `json:"region"` + AccessKeyID string `json:"access_key_id" required:"true"` + SecretAccessKey string `json:"secret_access_key" required:"true"` + CustomHost string `json:"custom_host"` + SignURLExpire int `json:"sign_url_expire" type:"number" default:"4"` + Placeholder string `json:"placeholder"` + ForcePathStyle bool `json:"force_path_style"` + ListObjectVersion string `json:"list_object_version" type:"select" options:"v1,v2" default:"v1"` +} + +var config = driver.Config{ + Name: "S3", + LocalSort: true, +} + +func New() driver.Driver { + return &S3{} +} + +func init() { + op.RegisterDriver(config, New) +} diff --git a/drivers/s3/types.go b/drivers/s3/types.go new file mode 100644 index 00000000000..3ed7f97237d --- /dev/null +++ b/drivers/s3/types.go @@ -0,0 +1 @@ +package s3 diff --git a/drivers/s3/util.go b/drivers/s3/util.go new file mode 100644 index 00000000000..35ad898142a --- /dev/null +++ b/drivers/s3/util.go @@ -0,0 +1,178 @@ +package s3 + +import ( + "errors" + "net/http" + "path" + "strings" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +func (d *S3) initSession() error { + cfg := &aws.Config{ + Credentials: credentials.NewStaticCredentials(d.AccessKeyID, d.SecretAccessKey, ""), + Region: &d.Region, + Endpoint: &d.Endpoint, + S3ForcePathStyle: aws.Bool(d.ForcePathStyle), + } + var err error + d.Session, err = session.NewSession(cfg) + return err +} + +func (d *S3) getClient(link bool) *s3.S3 { + client := s3.New(d.Session) + if link && d.CustomHost != "" { + client.Handlers.Build.PushBack(func(r *request.Request) { + if r.HTTPRequest.Method != http.MethodGet { + return + } + r.HTTPRequest.URL.Host = d.CustomHost + }) + } + return client +} + +func getKey(path string, dir bool) string { + path = strings.TrimPrefix(path, "/") + if path != "" && dir { + path += "/" + } + return path +} + +var defaultPlaceholderName = ".placeholder" + +func getPlaceholderName(placeholder string) string { + if placeholder == "" { + return defaultPlaceholderName + } + return placeholder +} + +func (d *S3) listV1(prefix string) ([]model.Obj, error) { + prefix = getKey(prefix, true) + log.Debugf("list: %s", prefix) + files := make([]model.Obj, 0) + marker := "" + for { + input := &s3.ListObjectsInput{ + Bucket: &d.Bucket, + Marker: &marker, + Prefix: &prefix, + Delimiter: aws.String("/"), + } + listObjectsResult, err := d.client.ListObjects(input) + if err != nil { + return nil, err + } + for _, object := range listObjectsResult.CommonPrefixes { + name := path.Base(strings.Trim(*object.Prefix, "/")) + file := model.Object{ + //Id: *object.Key, + Name: name, + Modified: d.Modified, + IsFolder: true, + } + files = append(files, &file) + } + for _, object := range listObjectsResult.Contents { + name := path.Base(*object.Key) + if name == getPlaceholderName(d.Placeholder) { + continue + } + file := model.Object{ + //Id: *object.Key, + Name: name, + Size: *object.Size, + Modified: *object.LastModified, + } + files = append(files, &file) + } + if listObjectsResult.IsTruncated == nil { + return nil, errors.New("IsTruncated nil") + } + if *listObjectsResult.IsTruncated { + marker = *listObjectsResult.NextMarker + } else { + break + } + } + return files, nil +} + +func (d *S3) listV2(prefix string) ([]model.Obj, error) { + prefix = getKey(prefix, true) + files := make([]model.Obj, 0) + var continuationToken, startAfter *string + for { + input := &s3.ListObjectsV2Input{ + Bucket: &d.Bucket, + ContinuationToken: continuationToken, + Prefix: &prefix, + Delimiter: aws.String("/"), + StartAfter: startAfter, + } + listObjectsResult, err := d.client.ListObjectsV2(input) + if err != nil { + return nil, err + } + log.Debugf("resp: %+v", listObjectsResult) + for _, object := range listObjectsResult.CommonPrefixes { + name := path.Base(strings.Trim(*object.Prefix, "/")) + file := model.Object{ + //Id: *object.Key, + Name: name, + Modified: d.Modified, + IsFolder: true, + } + files = append(files, &file) + } + for _, object := range listObjectsResult.Contents { + name := path.Base(*object.Key) + if name == getPlaceholderName(d.Placeholder) { + continue + } + file := model.Object{ + //Id: *object.Key, + Name: name, + Size: *object.Size, + Modified: *object.LastModified, + } + files = append(files, &file) + } + if !aws.BoolValue(listObjectsResult.IsTruncated) { + break + } + if listObjectsResult.NextContinuationToken != nil { + continuationToken = listObjectsResult.NextContinuationToken + continue + } + if len(listObjectsResult.Contents) == 0 { + break + } + startAfter = listObjectsResult.Contents[len(listObjectsResult.Contents)-1].Key + } + return files, nil +} + +func (d *S3) copy(src string, dst string, isDir bool) error { + srcKey := getKey(src, isDir) + dstKey := getKey(dst, isDir) + input := &s3.CopyObjectInput{ + Bucket: &d.Bucket, + CopySource: &srcKey, + Key: &dstKey, + } + _, err := d.client.CopyObject(input) + return err +} diff --git a/drivers/template/driver.go b/drivers/template/driver.go index 6f605ea7e45..bb008541e32 100644 --- a/drivers/template/driver.go +++ b/drivers/template/driver.go @@ -43,7 +43,7 @@ func (d *Template) List(ctx context.Context, dir model.Obj, args model.ListArgs) } func (d *Template) Get(ctx context.Context, path string) (model.Obj, error) { - // TODO this is optional + // this is optional return nil, errs.NotImplement }