From 96380a50da715a64cf6ca1026a82334940eb416d Mon Sep 17 00:00:00 2001 From: Noah Hsu Date: Tue, 28 Jun 2022 21:58:46 +0800 Subject: [PATCH] feat: file proxy handle --- internal/fs/copy.go | 2 +- internal/fs/fs.go | 8 +-- internal/fs/link.go | 4 +- internal/operations/fs.go | 20 +++---- internal/sign/sign.go | 9 +++ server/common/sign.go | 13 +++++ server/controllers/down.go | 107 ++++++++++++++++++++--------------- server/controllers/fsget.go | 2 +- server/controllers/fslist.go | 15 +---- server/middlewares/down.go | 54 ++++++++++++++++++ server/router.go | 3 +- 11 files changed, 157 insertions(+), 80 deletions(-) create mode 100644 server/common/sign.go create mode 100644 server/middlewares/down.go diff --git a/internal/fs/copy.go b/internal/fs/copy.go index 50fe85a1b17..1e53f3f76df 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -85,7 +85,7 @@ func copyFileBetween2Accounts(tsk *task.Task[uint64], srcAccount, dstAccount dri if err != nil { return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath) } - link, err := operations.Link(tsk.Ctx, srcAccount, srcFilePath, model.LinkArgs{}) + link, _, err := operations.Link(tsk.Ctx, srcAccount, srcFilePath, model.LinkArgs{}) if err != nil { return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) } diff --git a/internal/fs/fs.go b/internal/fs/fs.go index 56a8a3226a8..c898789aeb2 100644 --- a/internal/fs/fs.go +++ b/internal/fs/fs.go @@ -30,13 +30,13 @@ func Get(ctx context.Context, path string) (model.Obj, error) { return res, nil } -func Link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, error) { - res, err := link(ctx, path, args) +func Link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { + res, file, err := link(ctx, path, args) if err != nil { log.Errorf("failed link %s: %+v", path, err) - return nil, err + return nil, nil, err } - return res, nil + return res, file, nil } func MakeDir(ctx context.Context, path string) error { diff --git a/internal/fs/link.go b/internal/fs/link.go index 3b4cd680624..9b971a99578 100644 --- a/internal/fs/link.go +++ b/internal/fs/link.go @@ -7,10 +7,10 @@ import ( "github.com/pkg/errors" ) -func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, error) { +func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { account, actualPath, err := operations.GetAccountAndActualPath(path) if err != nil { - return nil, errors.WithMessage(err, "failed get account") + return nil, nil, errors.WithMessage(err, "failed get account") } return operations.Link(ctx, account, actualPath, args) } diff --git a/internal/operations/fs.go b/internal/operations/fs.go index a9183a1cdf6..161d5c49aea 100644 --- a/internal/operations/fs.go +++ b/internal/operations/fs.go @@ -115,19 +115,19 @@ var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16)) var linkG singleflight.Group[*model.Link] // Link get link, if is an url. should have an expiry time -func Link(ctx context.Context, account driver.Driver, path string, args model.LinkArgs) (*model.Link, error) { +func Link(ctx context.Context, account driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { + file, err := Get(ctx, account, path) + if err != nil { + return nil, nil, errors.WithMessage(err, "failed to get file") + } + if file.IsDir() { + return nil, nil, errors.WithStack(errs.NotFile) + } key := stdpath.Join(account.GetAccount().VirtualPath, path) if link, ok := linkCache.Get(key); ok { - return link, nil + return link, file, nil } fn := func() (*model.Link, error) { - file, err := Get(ctx, account, path) - if err != nil { - return nil, errors.WithMessage(err, "failed to get file") - } - if file.IsDir() { - return nil, errors.WithStack(errs.NotFile) - } link, err := account.Link(ctx, file, args) if err != nil { return nil, errors.WithMessage(err, "failed get link") @@ -138,7 +138,7 @@ func Link(ctx context.Context, account driver.Driver, path string, args model.Li return link, nil } link, err, _ := linkG.Do(key, fn) - return link, err + return link, file, err } func MakeDir(ctx context.Context, account driver.Driver, path string) error { diff --git a/internal/sign/sign.go b/internal/sign/sign.go index a8fb68233a7..4e2d62f3ae3 100644 --- a/internal/sign/sign.go +++ b/internal/sign/sign.go @@ -10,6 +10,15 @@ import ( var once sync.Once var instance sign.Sign +func Sign(data string) string { + expire := setting.GetIntSetting("link_expiration", 0) + if expire == 0 { + return NotExpired(data) + } else { + return WithDuration(data, time.Duration(expire)*time.Hour) + } +} + func WithDuration(data string, d time.Duration) string { once.Do(Instance) return instance.Sign(data, time.Now().Add(d).Unix()) diff --git a/server/common/sign.go b/server/common/sign.go new file mode 100644 index 00000000000..5b509eda0b3 --- /dev/null +++ b/server/common/sign.go @@ -0,0 +1,13 @@ +package common + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/sign" +) + +func Sign(obj model.Obj) string { + if obj.IsDir() { + return "" + } + return sign.Sign(obj.GetName()) +} diff --git a/server/controllers/down.go b/server/controllers/down.go index 81afbe19016..14a82b00c65 100644 --- a/server/controllers/down.go +++ b/server/controllers/down.go @@ -1,101 +1,114 @@ package controllers import ( + "fmt" + "github.com/alist-org/alist/v3/internal/sign" stdpath "path" "strings" - "github.com/alist-org/alist/v3/internal/db" "github.com/alist-org/alist/v3/internal/driver" - "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/setting" - "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" - "github.com/pkg/errors" ) func Down(c *gin.Context) { - rawPath := parsePath(c.Param("path")) + rawPath := c.MustGet("path").(string) filename := stdpath.Base(rawPath) - meta, err := db.GetNearestMeta(rawPath) + account, err := fs.GetAccount(rawPath) if err != nil { - if !errors.Is(errors.Cause(err), errs.MetaNotFound) { - common.ErrorResp(c, err, 500, true) - return - } + common.ErrorResp(c, err, 500) + return } - // verify sign - if needSign(meta, rawPath) { - s := c.Param("sign") - err = sign.Verify(filename, s) + if shouldProxy(account, filename) { + Proxy(c) + return + } else { + link, _, err := fs.Link(c, rawPath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, + }) if err != nil { - common.ErrorResp(c, err, 401) + common.ErrorResp(c, err, 500) return } + c.Redirect(302, link.URL) } +} + +func Proxy(c *gin.Context) { + rawPath := c.MustGet("path").(string) + filename := stdpath.Base(rawPath) account, err := fs.GetAccount(rawPath) if err != nil { common.ErrorResp(c, err, 500) return } - if needProxy(account, filename) { - link, err := fs.Link(c, rawPath, model.LinkArgs{ + if canProxy(account, filename) { + downProxyUrl := account.GetAccount().DownProxyUrl + if downProxyUrl != "" { + _, ok := c.GetQuery("d") + if ok { + URL := fmt.Sprintf("%s%s?sign=%s", strings.Split(downProxyUrl, "\n")[0], rawPath, sign.Sign(filename)) + c.Redirect(302, URL) + return + } + } + link, file, err := fs.Link(c, rawPath, model.LinkArgs{ Header: c.Request.Header, }) if err != nil { common.ErrorResp(c, err, 500) return } - obj, err := fs.Get(c, rawPath) - if err != nil { - common.ErrorResp(c, err, 500) - return - } - err = common.Proxy(c.Writer, c.Request, link, obj) + err = common.Proxy(c.Writer, c.Request, link, file) if err != nil { common.ErrorResp(c, err, 500, true) return } } else { - link, err := fs.Link(c, rawPath, model.LinkArgs{ - IP: c.ClientIP(), - Header: c.Request.Header, - }) - if err != nil { - common.ErrorResp(c, err, 500) - return - } - c.Redirect(302, link.URL) + common.ErrorStrResp(c, "proxy not allowed", 403) + return } } -// TODO: implement -// path maybe contains # ? etc. -func parsePath(path string) string { - return utils.StandardizePath(path) -} - -func needSign(meta *model.Meta, path string) bool { - if meta == nil || meta.Password == "" { - return false +// TODO need optimize +// when should be proxy? +// 1. config.MustProxy() +// 2. account.WebProxy +// 3. proxy_types +func shouldProxy(account driver.Driver, filename string) bool { + if account.Config().MustProxy() || account.GetAccount().WebProxy { + return true } - if !meta.SubFolder && path != meta.Path { - return false + proxyTypes := setting.GetByKey("proxy_types") + if strings.Contains(proxyTypes, utils.Ext(filename)) { + return true } - return true + return false } -func needProxy(account driver.Driver, filename string) bool { - config := account.Config() - if config.MustProxy() { +// TODO need optimize +// when can be proxy? +// 1. text file +// 2. config.MustProxy() +// 3. account.WebProxy +// 4. proxy_types +// solution: text_file + shouldProxy() +func canProxy(account driver.Driver, filename string) bool { + if account.Config().MustProxy() || account.GetAccount().WebProxy { return true } proxyTypes := setting.GetByKey("proxy_types") if strings.Contains(proxyTypes, utils.Ext(filename)) { return true } + textTypes := setting.GetByKey("text_types") + if strings.Contains(textTypes, utils.Ext(filename)) { + return true + } return false } diff --git a/server/controllers/fsget.go b/server/controllers/fsget.go index 66a1ccb4fdf..35f6c935c80 100644 --- a/server/controllers/fsget.go +++ b/server/controllers/fsget.go @@ -53,7 +53,7 @@ func FsGet(c *gin.Context) { Size: obj.GetSize(), IsDir: obj.IsDir(), Modified: obj.ModTime(), - Sign: Sign(obj), + Sign: common.Sign(obj), }, // TODO: set raw url }) diff --git a/server/controllers/fslist.go b/server/controllers/fslist.go index 1d5cc7b3eb3..fcb0ba9e4d7 100644 --- a/server/controllers/fslist.go +++ b/server/controllers/fslist.go @@ -9,7 +9,6 @@ import ( "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/setting" - "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" @@ -108,20 +107,8 @@ func toObjResp(objs []model.Obj, path string, baseURL string) []ObjResp { Size: obj.GetSize(), IsDir: obj.IsDir(), Modified: obj.ModTime(), - Sign: Sign(obj), + Sign: common.Sign(obj), }) } return resp } - -func Sign(obj model.Obj) string { - if obj.IsDir() { - return "" - } - expire := setting.GetIntSetting("link_expiration", 0) - if expire == 0 { - return sign.NotExpired(obj.GetName()) - } else { - return sign.WithDuration(obj.GetName(), time.Duration(expire)*time.Hour) - } -} diff --git a/server/middlewares/down.go b/server/middlewares/down.go new file mode 100644 index 00000000000..ef1583e0681 --- /dev/null +++ b/server/middlewares/down.go @@ -0,0 +1,54 @@ +package middlewares + +import ( + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/sign" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + stdpath "path" +) + +func Down(c *gin.Context) { + rawPath := parsePath(c.Param("path")) + c.Set("path", rawPath) + filename := stdpath.Base(rawPath) + meta, err := db.GetNearestMeta(rawPath) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + c.Set("meta", meta) + // verify sign + if needSign(meta, rawPath) { + s := c.Param("sign") + err = sign.Verify(filename, s) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + } + c.Next() +} + +// TODO: implement +// path maybe contains # ? etc. +func parsePath(path string) string { + return utils.StandardizePath(path) +} + +func needSign(meta *model.Meta, path string) bool { + if meta == nil || meta.Password == "" { + return false + } + if !meta.SubFolder && path != meta.Path { + return false + } + return true +} diff --git a/server/router.go b/server/router.go index 23066e3753e..c46584c151d 100644 --- a/server/router.go +++ b/server/router.go @@ -13,7 +13,8 @@ func Init(r *gin.Engine) { common.SecretKey = []byte(conf.Conf.JwtSecret) Cors(r) - r.GET("/d/*path", controllers.Down) + r.GET("/d/*path", middlewares.Down, controllers.Down) + r.GET("/p/*path", middlewares.Down, controllers.Proxy) api := r.Group("/api", middlewares.Auth) api.POST("/auth/login", controllers.Login)