Skip to content

Commit

Permalink
feat: file proxy handle
Browse files Browse the repository at this point in the history
  • Loading branch information
xhofe committed Jun 28, 2022
1 parent d1efec4 commit 96380a5
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 80 deletions.
2 changes: 1 addition & 1 deletion internal/fs/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func copyFileBetween2Accounts(tsk *task.Task[uint64], srcAccount, dstAccount dri
if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath)
}
link, err := operations.Link(tsk.Ctx, srcAccount, srcFilePath, model.LinkArgs{})
link, _, err := operations.Link(tsk.Ctx, srcAccount, srcFilePath, model.LinkArgs{})
if err != nil {
return errors.WithMessagef(err, "failed get [%s] link", srcFilePath)
}
Expand Down
8 changes: 4 additions & 4 deletions internal/fs/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ func Get(ctx context.Context, path string) (model.Obj, error) {
return res, nil
}

func Link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, error) {
res, err := link(ctx, path, args)
func Link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
res, file, err := link(ctx, path, args)
if err != nil {
log.Errorf("failed link %s: %+v", path, err)
return nil, err
return nil, nil, err
}
return res, nil
return res, file, nil
}

func MakeDir(ctx context.Context, path string) error {
Expand Down
4 changes: 2 additions & 2 deletions internal/fs/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (
"github.com/pkg/errors"
)

func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, error) {
func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
account, actualPath, err := operations.GetAccountAndActualPath(path)
if err != nil {
return nil, errors.WithMessage(err, "failed get account")
return nil, nil, errors.WithMessage(err, "failed get account")
}
return operations.Link(ctx, account, actualPath, args)
}
20 changes: 10 additions & 10 deletions internal/operations/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,19 @@ var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16))
var linkG singleflight.Group[*model.Link]

// Link get link, if is an url. should have an expiry time
func Link(ctx context.Context, account driver.Driver, path string, args model.LinkArgs) (*model.Link, error) {
func Link(ctx context.Context, account driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
file, err := Get(ctx, account, path)
if err != nil {
return nil, nil, errors.WithMessage(err, "failed to get file")
}
if file.IsDir() {
return nil, nil, errors.WithStack(errs.NotFile)
}
key := stdpath.Join(account.GetAccount().VirtualPath, path)
if link, ok := linkCache.Get(key); ok {
return link, nil
return link, file, nil
}
fn := func() (*model.Link, error) {
file, err := Get(ctx, account, path)
if err != nil {
return nil, errors.WithMessage(err, "failed to get file")
}
if file.IsDir() {
return nil, errors.WithStack(errs.NotFile)
}
link, err := account.Link(ctx, file, args)
if err != nil {
return nil, errors.WithMessage(err, "failed get link")
Expand All @@ -138,7 +138,7 @@ func Link(ctx context.Context, account driver.Driver, path string, args model.Li
return link, nil
}
link, err, _ := linkG.Do(key, fn)
return link, err
return link, file, err
}

func MakeDir(ctx context.Context, account driver.Driver, path string) error {
Expand Down
9 changes: 9 additions & 0 deletions internal/sign/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ import (
var once sync.Once
var instance sign.Sign

func Sign(data string) string {
expire := setting.GetIntSetting("link_expiration", 0)
if expire == 0 {
return NotExpired(data)
} else {
return WithDuration(data, time.Duration(expire)*time.Hour)
}
}

func WithDuration(data string, d time.Duration) string {
once.Do(Instance)
return instance.Sign(data, time.Now().Add(d).Unix())
Expand Down
13 changes: 13 additions & 0 deletions server/common/sign.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package common

import (
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/sign"
)

func Sign(obj model.Obj) string {
if obj.IsDir() {
return ""
}
return sign.Sign(obj.GetName())
}
107 changes: 60 additions & 47 deletions server/controllers/down.go
Original file line number Diff line number Diff line change
@@ -1,101 +1,114 @@
package controllers

import (
"fmt"
"github.com/alist-org/alist/v3/internal/sign"
stdpath "path"
"strings"

"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/internal/sign"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)

func Down(c *gin.Context) {
rawPath := parsePath(c.Param("path"))
rawPath := c.MustGet("path").(string)
filename := stdpath.Base(rawPath)
meta, err := db.GetNearestMeta(rawPath)
account, err := fs.GetAccount(rawPath)
if err != nil {
if !errors.Is(errors.Cause(err), errs.MetaNotFound) {
common.ErrorResp(c, err, 500, true)
return
}
common.ErrorResp(c, err, 500)
return
}
// verify sign
if needSign(meta, rawPath) {
s := c.Param("sign")
err = sign.Verify(filename, s)
if shouldProxy(account, filename) {
Proxy(c)
return
} else {
link, _, err := fs.Link(c, rawPath, model.LinkArgs{
IP: c.ClientIP(),
Header: c.Request.Header,
})
if err != nil {
common.ErrorResp(c, err, 401)
common.ErrorResp(c, err, 500)
return
}
c.Redirect(302, link.URL)
}
}

func Proxy(c *gin.Context) {
rawPath := c.MustGet("path").(string)
filename := stdpath.Base(rawPath)
account, err := fs.GetAccount(rawPath)
if err != nil {
common.ErrorResp(c, err, 500)
return
}
if needProxy(account, filename) {
link, err := fs.Link(c, rawPath, model.LinkArgs{
if canProxy(account, filename) {
downProxyUrl := account.GetAccount().DownProxyUrl
if downProxyUrl != "" {
_, ok := c.GetQuery("d")
if ok {
URL := fmt.Sprintf("%s%s?sign=%s", strings.Split(downProxyUrl, "\n")[0], rawPath, sign.Sign(filename))
c.Redirect(302, URL)
return
}
}
link, file, err := fs.Link(c, rawPath, model.LinkArgs{
Header: c.Request.Header,
})
if err != nil {
common.ErrorResp(c, err, 500)
return
}
obj, err := fs.Get(c, rawPath)
if err != nil {
common.ErrorResp(c, err, 500)
return
}
err = common.Proxy(c.Writer, c.Request, link, obj)
err = common.Proxy(c.Writer, c.Request, link, file)
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
} else {
link, err := fs.Link(c, rawPath, model.LinkArgs{
IP: c.ClientIP(),
Header: c.Request.Header,
})
if err != nil {
common.ErrorResp(c, err, 500)
return
}
c.Redirect(302, link.URL)
common.ErrorStrResp(c, "proxy not allowed", 403)
return
}
}

// TODO: implement
// path maybe contains # ? etc.
func parsePath(path string) string {
return utils.StandardizePath(path)
}

func needSign(meta *model.Meta, path string) bool {
if meta == nil || meta.Password == "" {
return false
// TODO need optimize
// when should be proxy?
// 1. config.MustProxy()
// 2. account.WebProxy
// 3. proxy_types
func shouldProxy(account driver.Driver, filename string) bool {
if account.Config().MustProxy() || account.GetAccount().WebProxy {
return true
}
if !meta.SubFolder && path != meta.Path {
return false
proxyTypes := setting.GetByKey("proxy_types")
if strings.Contains(proxyTypes, utils.Ext(filename)) {
return true
}
return true
return false
}

func needProxy(account driver.Driver, filename string) bool {
config := account.Config()
if config.MustProxy() {
// TODO need optimize
// when can be proxy?
// 1. text file
// 2. config.MustProxy()
// 3. account.WebProxy
// 4. proxy_types
// solution: text_file + shouldProxy()
func canProxy(account driver.Driver, filename string) bool {
if account.Config().MustProxy() || account.GetAccount().WebProxy {
return true
}
proxyTypes := setting.GetByKey("proxy_types")
if strings.Contains(proxyTypes, utils.Ext(filename)) {
return true
}
textTypes := setting.GetByKey("text_types")
if strings.Contains(textTypes, utils.Ext(filename)) {
return true
}
return false
}
2 changes: 1 addition & 1 deletion server/controllers/fsget.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func FsGet(c *gin.Context) {
Size: obj.GetSize(),
IsDir: obj.IsDir(),
Modified: obj.ModTime(),
Sign: Sign(obj),
Sign: common.Sign(obj),
},
// TODO: set raw url
})
Expand Down
15 changes: 1 addition & 14 deletions server/controllers/fslist.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/internal/sign"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -108,20 +107,8 @@ func toObjResp(objs []model.Obj, path string, baseURL string) []ObjResp {
Size: obj.GetSize(),
IsDir: obj.IsDir(),
Modified: obj.ModTime(),
Sign: Sign(obj),
Sign: common.Sign(obj),
})
}
return resp
}

func Sign(obj model.Obj) string {
if obj.IsDir() {
return ""
}
expire := setting.GetIntSetting("link_expiration", 0)
if expire == 0 {
return sign.NotExpired(obj.GetName())
} else {
return sign.WithDuration(obj.GetName(), time.Duration(expire)*time.Hour)
}
}
54 changes: 54 additions & 0 deletions server/middlewares/down.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package middlewares

import (
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/sign"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
stdpath "path"
)

func Down(c *gin.Context) {
rawPath := parsePath(c.Param("path"))
c.Set("path", rawPath)
filename := stdpath.Base(rawPath)
meta, err := db.GetNearestMeta(rawPath)
if err != nil {
if !errors.Is(errors.Cause(err), errs.MetaNotFound) {
common.ErrorResp(c, err, 500, true)
return
}
}
c.Set("meta", meta)
// verify sign
if needSign(meta, rawPath) {
s := c.Param("sign")
err = sign.Verify(filename, s)
if err != nil {
common.ErrorResp(c, err, 401)
c.Abort()
return
}
}
c.Next()
}

// TODO: implement
// path maybe contains # ? etc.
func parsePath(path string) string {
return utils.StandardizePath(path)
}

func needSign(meta *model.Meta, path string) bool {
if meta == nil || meta.Password == "" {
return false
}
if !meta.SubFolder && path != meta.Path {
return false
}
return true
}
3 changes: 2 additions & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ func Init(r *gin.Engine) {
common.SecretKey = []byte(conf.Conf.JwtSecret)
Cors(r)

r.GET("/d/*path", controllers.Down)
r.GET("/d/*path", middlewares.Down, controllers.Down)
r.GET("/p/*path", middlewares.Down, controllers.Proxy)

api := r.Group("/api", middlewares.Auth)
api.POST("/auth/login", controllers.Login)
Expand Down

0 comments on commit 96380a5

Please sign in to comment.