diff --git a/drivers/all.go b/drivers/all.go index 1a9f42dd199..9fe2faaa2b9 100644 --- a/drivers/all.go +++ b/drivers/all.go @@ -33,6 +33,7 @@ import ( _ "github.com/alist-org/alist/v3/drivers/teambition" _ "github.com/alist-org/alist/v3/drivers/terabox" _ "github.com/alist-org/alist/v3/drivers/thunder" + _ "github.com/alist-org/alist/v3/drivers/trainbit" _ "github.com/alist-org/alist/v3/drivers/uss" _ "github.com/alist-org/alist/v3/drivers/virtual" _ "github.com/alist-org/alist/v3/drivers/webdav" diff --git a/drivers/trainbit/driver.go b/drivers/trainbit/driver.go new file mode 100644 index 00000000000..ee71ba352cc --- /dev/null +++ b/drivers/trainbit/driver.go @@ -0,0 +1,142 @@ +package trainbit + +import ( + "context" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "net/url" + "strings" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" +) + +type Trainbit struct { + model.Storage + Addition +} + +var apiExpiredate, guid string + +func (d *Trainbit) Config() driver.Config { + return config +} + +func (d *Trainbit) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Trainbit) Init(ctx context.Context) error { + http.DefaultClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + var err error + apiExpiredate, guid, err = getToken(d.ApiKey, d.AUSHELLPORTAL) + if err != nil { + return err + } + return nil +} + +func (d *Trainbit) Drop(ctx context.Context) error { + return nil +} + +func (d *Trainbit) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + form := make(url.Values) + form.Set("parentid", strings.Split(dir.GetID(), "_")[0]) + res, err := postForm("https://trainbit.com/lib/api/v1/listoffiles", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL) + if err != nil { + return nil, err + } + data, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + var jsonData any + json.Unmarshal(data, &jsonData) + if err != nil { + return nil, err + } + object, err := parseRawFileObject(jsonData.(map[string]any)["items"].([]any)) + if err != nil { + return nil, err + } + return object, nil +} + +func (d *Trainbit) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + res, err := get(fmt.Sprintf("https://trainbit.com/files/%s/", strings.Split(file.GetID(), "_")[0]), d.ApiKey, d.AUSHELLPORTAL) + if err != nil { + return nil, err + } + return &model.Link{ + URL: res.Header.Get("Location"), + }, nil +} + +func (d *Trainbit) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + form := make(url.Values) + form.Set("name", local2provider(dirName, true)) + form.Set("parentid", strings.Split(parentDir.GetID(), "_")[0]) + _, err := postForm("https://trainbit.com/lib/api/v1/createfolder", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL) + return err +} + +func (d *Trainbit) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + form := make(url.Values) + form.Set("sourceid", strings.Split(srcObj.GetID(), "_")[0]) + form.Set("destinationid", strings.Split(dstDir.GetID(), "_")[0]) + _, err := postForm("https://trainbit.com/lib/api/v1/move", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL) + return err +} + +func (d *Trainbit) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + form := make(url.Values) + form.Set("id", strings.Split(srcObj.GetID(), "_")[0]) + form.Set("name", local2provider(newName, srcObj.IsDir())) + _, err := postForm("https://trainbit.com/lib/api/v1/edit", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL) + return err +} + +func (d *Trainbit) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotImplement +} + +func (d *Trainbit) Remove(ctx context.Context, obj model.Obj) error { + form := make(url.Values) + form.Set("id", strings.Split(obj.GetID(), "_")[0]) + _, err := postForm("https://trainbit.com/lib/api/v1/delete", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL) + return err +} + +func (d *Trainbit) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + endpoint, _ := url.Parse("https://tb28.trainbit.com/api/upload/send_raw/") + query := &url.Values{} + query.Add("q", strings.Split(dstDir.GetID(), "_")[1]) + query.Add("guid", guid) + query.Add("name", url.QueryEscape(local2provider(stream.GetName(), false))) + endpoint.RawQuery = query.Encode() + var total int64 + total = 0 + progressReader := &ProgressReader{ + stream, + func(byteNum int) { + total += int64(byteNum) + up(int(math.Round(float64(total) / float64(stream.GetSize()) * 100))) + }, + } + req, err := http.NewRequest(http.MethodPost, endpoint.String(), progressReader) + if err != nil { + return err + } + req.Header.Set("Content-Type", "text/json; charset=UTF-8") + _, err = http.DefaultClient.Do(req) + return err +} + +var _ driver.Driver = (*Trainbit)(nil) diff --git a/drivers/trainbit/meta.go b/drivers/trainbit/meta.go new file mode 100644 index 00000000000..59c09d77e1c --- /dev/null +++ b/drivers/trainbit/meta.go @@ -0,0 +1,29 @@ +package trainbit + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + AUSHELLPORTAL string `json:"AUSHELLPORTAL" required:"true"` + ApiKey string `json:"apikey" required:"true"` +} + +var config = driver.Config{ + Name: "Trainbit", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "0_000", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Trainbit{} + }) +} diff --git a/drivers/trainbit/types.go b/drivers/trainbit/types.go new file mode 100644 index 00000000000..4de1a0abdf3 --- /dev/null +++ b/drivers/trainbit/types.go @@ -0,0 +1 @@ +package trainbit \ No newline at end of file diff --git a/drivers/trainbit/util.go b/drivers/trainbit/util.go new file mode 100644 index 00000000000..87568b6166e --- /dev/null +++ b/drivers/trainbit/util.go @@ -0,0 +1,150 @@ +package trainbit + +import ( + "io" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type ProgressReader struct { + io.Reader + reporter func(byteNum int) +} + +func (progressReader *ProgressReader) Read(data []byte) (int, error) { + byteNum, err := progressReader.Reader.Read(data) + progressReader.reporter(byteNum) + return byteNum, err +} + +func get(url string, apiKey string, AUSHELLPORTAL string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.AddCookie(&http.Cookie{ + Name: ".AUSHELLPORTAL", + Value: AUSHELLPORTAL, + MaxAge: 2 * 60, + }) + req.AddCookie(&http.Cookie{ + Name: "retkeyapi", + Value: apiKey, + MaxAge: 2 * 60, + }) + res, err := http.DefaultClient.Do(req) + return res, err +} + +func postForm(endpoint string, data url.Values, apiExpiredate string, apiKey string, AUSHELLPORTAL string) (*http.Response, error) { + extData := make(url.Values) + for key, value := range data { + extData[key] = make([]string, len(value)) + copy(extData[key], value) + } + extData.Set("apikey", apiKey) + extData.Set("expiredate", apiExpiredate) + req, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(extData.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{ + Name: ".AUSHELLPORTAL", + Value: AUSHELLPORTAL, + MaxAge: 2 * 60, + }) + req.AddCookie(&http.Cookie{ + Name: "retkeyapi", + Value: apiKey, + MaxAge: 2 * 60, + }) + res, err := http.DefaultClient.Do(req) + return res, err +} + +func getToken(apiKey string, AUSHELLPORTAL string) (string, string, error) { + res, err := get("https://trainbit.com/files/", apiKey, AUSHELLPORTAL) + if err != nil { + return "", "", err + } + data, err := io.ReadAll(res.Body) + if err != nil { + return "", "", err + } + text := string(data) + apiExpiredateReg := regexp.MustCompile(`core.api.expiredate = '([^']*)';`) + result := apiExpiredateReg.FindAllStringSubmatch(text, -1) + apiExpiredate := result[0][1] + guidReg := regexp.MustCompile(`app.vars.upload.guid = '([^']*)';`) + result = guidReg.FindAllStringSubmatch(text, -1) + guid := result[0][1] + return apiExpiredate, guid, nil +} + +func local2provider(filename string, isFolder bool) string { + filename = strings.Replace(filename, "%", url.QueryEscape("%"), -1) + filename = strings.Replace(filename, "/", url.QueryEscape("/"), -1) + filename = strings.Replace(filename, ":", url.QueryEscape(":"), -1) + filename = strings.Replace(filename, "*", url.QueryEscape("*"), -1) + filename = strings.Replace(filename, "?", url.QueryEscape("?"), -1) + filename = strings.Replace(filename, "\"", url.QueryEscape("\""), -1) + filename = strings.Replace(filename, "<", url.QueryEscape("<"), -1) + filename = strings.Replace(filename, ">", url.QueryEscape(">"), -1) + filename = strings.Replace(filename, "|", url.QueryEscape("|"), -1) + if isFolder { + return filename + } + return strings.Join([]string{filename, ".delete_suffix."}, "") +} + +func provider2local(filename string) string { + index := strings.LastIndex(filename, ".delete_suffix.") + if index != -1 { + filename = filename[:index] + } + rawName := strings.Replace(filename, url.QueryEscape("/"), "/", -1) + rawName = strings.Replace(rawName, url.QueryEscape(":"), ":", -1) + rawName = strings.Replace(rawName, url.QueryEscape("*"), "*", -1) + rawName = strings.Replace(rawName, url.QueryEscape("?"), "?", -1) + rawName = strings.Replace(rawName, url.QueryEscape("\""), "\"", -1) + rawName = strings.Replace(rawName, url.QueryEscape("<"), "<", -1) + rawName = strings.Replace(rawName, url.QueryEscape(">"), ">", -1) + rawName = strings.Replace(rawName, url.QueryEscape("|"), "|", -1) + rawName = strings.Replace(rawName, url.QueryEscape("%"), "%", -1) + return rawName +} + +func parseRawFileObject(rawObject []any) ([]model.Obj, error) { + objectList := make([]model.Obj, 0) + for _, each := range rawObject { + object := each.(map[string]any) + if object["id"].(string) == "0" { + continue + } + isFolder := int64(object["ty"].(float64)) == 1 + var name string + if isFolder { + name = object["name"].(string) + } else { + name = strings.Join([]string{object["name"].(string), object["ext"].(string)}, ".") + } + modified, err := time.Parse("2006/01/02 15:04:05", object["modified"].(string)) + if err != nil { + return nil, err + } + objectList = append(objectList, model.Obj(&model.Object{ + ID: strings.Join([]string{object["id"].(string), strings.Split(object["uploadurl"].(string), "=")[1]}, "_"), + Name: provider2local(name), + Size: int64(object["byte"].(float64)), + Modified: modified.Add(-210 * time.Minute), + IsFolder: isFolder, + })) + } + return objectList, nil +}