Skip to content

Commit

Permalink
feat: add copy to task manager
Browse files Browse the repository at this point in the history
  • Loading branch information
xhofe committed Jun 17, 2022
1 parent 53e969e commit fa6e918
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 70 deletions.
7 changes: 3 additions & 4 deletions drivers/local/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ func (d Driver) Config() driver.Config {

func (d *Driver) Init(ctx context.Context, account model.Account) error {
d.Account = account
addition := d.Account.Addition
err := utils.Json.UnmarshalFromString(addition, &d.Addition)
err := utils.Json.UnmarshalFromString(d.Account.Addition, &d.Addition)
if err != nil {
return errors.Wrap(err, "error while unmarshal addition")
}
Expand All @@ -32,7 +31,7 @@ func (d *Driver) Init(ctx context.Context, account model.Account) error {
} else {
d.SetStatus("OK")
}
operations.SaveDriverAccount(d)
operations.MustSaveDriverAccount(d)
return err
}

Expand Down Expand Up @@ -79,7 +78,7 @@ func (d *Driver) Remove(ctx context.Context, obj model.Obj) error {
panic("implement me")
}

func (d *Driver) Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer) error {
func (d *Driver) Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
//TODO implement me
panic("implement me")
}
Expand Down
4 changes: 3 additions & 1 deletion internal/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,7 @@ type Writer interface {
// Remove remove `object`
Remove(ctx context.Context, obj model.Obj) error
// Put upload `stream` to `parentDir`
Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer) error
Put(ctx context.Context, parentDir model.Obj, stream model.FileStreamer, up UpdateProgress) error
}

type UpdateProgress func(percentage float64)
38 changes: 29 additions & 9 deletions internal/fs/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,55 @@ package fs

import (
"context"
"fmt"
"github.com/alist-org/alist/v3/pkg/task"
"github.com/alist-org/alist/v3/pkg/utils"
stdpath "path"

"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/operations"
"github.com/alist-org/alist/v3/internal/task"
"github.com/pkg/errors"
)

var copyTaskManager = task.NewTaskManager()

func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string) error {
srcFile, err := operations.Get(ctx, srcAccount, srcPath)
func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string, setStatus func(status string)) error {
setStatus("getting src object")
srcObj, err := operations.Get(ctx, srcAccount, srcPath)
if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcPath)
}
if srcFile.IsDir() {
if srcObj.IsDir() {
setStatus("src object is dir, listing files")
files, err := operations.List(ctx, srcAccount, srcPath)
if err != nil {
return errors.WithMessagef(err, "failed list src [%s] files", srcPath)
}
for _, file := range files {
if utils.IsCanceled(ctx) {
return nil
}
srcFilePath := stdpath.Join(srcPath, file.GetName())
dstFilePath := stdpath.Join(dstPath, file.GetName())
if err := CopyBetween2Accounts(ctx, srcAccount, dstAccount, srcFilePath, dstFilePath); err != nil {
return errors.WithMessagef(err, "failed copy file [%s] to [%s]", srcFilePath, dstFilePath)
}
copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcFilePath, dstFilePath), func(task *task.Task) error {
return CopyBetween2Accounts(ctx, srcAccount, dstAccount, srcFilePath, dstFilePath, task.SetStatus)
})
}
} else {
copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcPath, dstPath), func(task *task.Task) error {
return CopyFileBetween2Accounts(task.Ctx, srcAccount, dstAccount, srcPath, dstPath, func(percentage float64) {
task.SetStatus(fmt.Sprintf("uploading: %2.f%", percentage))
})
})
}
return nil
}

func CopyFileBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Driver, srcPath, dstPath string, up driver.UpdateProgress) error {
srcFile, err := operations.Get(ctx, srcAccount, srcPath)
if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcPath)
}
link, err := operations.Link(ctx, srcAccount, srcPath, model.LinkArgs{})
if err != nil {
Expand All @@ -39,6 +60,5 @@ func CopyBetween2Accounts(ctx context.Context, srcAccount, dstAccount driver.Dri
if err != nil {
return errors.WithMessagef(err, "failed get [%s] stream", srcPath)
}
// TODO add as task
return operations.Put(ctx, dstAccount, dstPath, stream)
return operations.Put(ctx, dstAccount, dstPath, stream, up)
}
10 changes: 5 additions & 5 deletions internal/fs/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package fs
import (
"context"
"fmt"

"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/operations"
"github.com/alist-org/alist/v3/internal/task"
"github.com/alist-org/alist/v3/pkg/task"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -49,7 +48,7 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) (
if err != nil {
return false, errors.WithMessage(err, "failed get src account")
}
dstAccount, dstActualPath, err := operations.GetAccountAndActualPath(srcPath)
dstAccount, dstActualPath, err := operations.GetAccountAndActualPath(dstPath)
if err != nil {
return false, errors.WithMessage(err, "failed get dst account")
}
Expand All @@ -60,7 +59,7 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) (
// not in an account
// TODO add status set callback to put
copyTaskManager.Add(fmt.Sprintf("copy %s to %s", srcActualPath, dstActualPath), func(task *task.Task) error {
return CopyBetween2Accounts(context.TODO(), srcAccount, dstAccount, srcActualPath, dstActualPath)
return CopyBetween2Accounts(task.Ctx, srcAccount, dstAccount, srcActualPath, dstActualPath, task.SetStatus)
})
return true, nil
}
Expand All @@ -73,10 +72,11 @@ func Remove(ctx context.Context, account driver.Driver, path string) error {
return operations.Remove(ctx, account, actualPath)
}

// Put add as a put task
func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer) error {
account, actualParentPath, err := operations.GetAccountAndActualPath(parentPath)
if err != nil {
return errors.WithMessage(err, "failed get account")
}
return operations.Put(ctx, account, actualParentPath, file)
return operations.Put(ctx, account, actualParentPath, file, nil)
}
12 changes: 10 additions & 2 deletions internal/operations/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package operations

import (
"context"
log "github.com/sirupsen/logrus"
"sort"
"strings"
"time"
Expand Down Expand Up @@ -85,8 +86,15 @@ func UpdateAccount(ctx context.Context, account model.Account) error {
return nil
}

// SaveDriverAccount call from specific driver
func SaveDriverAccount(driver driver.Driver) error {
// MustSaveDriverAccount call from specific driver
func MustSaveDriverAccount(driver driver.Driver) {
err := saveDriverAccount(driver)
if err != nil {
log.Errorf("failed save driver account: %s", err)
}
}

func saveDriverAccount(driver driver.Driver) error {
account := driver.GetAccount()
addition := driver.GetAddition()
bytes, err := utils.Json.Marshal(addition)
Expand Down
8 changes: 6 additions & 2 deletions internal/operations/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func Remove(ctx context.Context, account driver.Driver, path string) error {
return account.Remove(ctx, obj)
}

func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer) error {
func Put(ctx context.Context, account driver.Driver, parentPath string, file model.FileStreamer, up driver.UpdateProgress) error {
err := MakeDir(ctx, account, parentPath)
if err != nil {
return errors.WithMessagef(err, "failed to make dir [%s]", parentPath)
Expand All @@ -192,5 +192,9 @@ func Put(ctx context.Context, account driver.Driver, parentPath string, file mod
if err != nil {
return errors.WithMessagef(err, "failed to get dir [%s]", parentPath)
}
return account.Put(ctx, parentDir, file)
// if up is nil, set a default to prevent panic
if up == nil {
up = func(p float64) {}
}
return account.Put(ctx, parentDir, file, up)
}
36 changes: 0 additions & 36 deletions internal/task/task.go

This file was deleted.

22 changes: 11 additions & 11 deletions internal/task/manager.go → pkg/task/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,37 @@ import (
"github.com/alist-org/alist/v3/pkg/generic_sync"
)

func NewTaskManager() *TaskManager {
return &TaskManager{
func NewTaskManager() *Manager {
return &Manager{
tasks: generic_sync.MapOf[int64, *Task]{},
curID: 0,
}
}

type TaskManager struct {
type Manager struct {
curID int64
tasks generic_sync.MapOf[int64, *Task]
}

func (tm *TaskManager) AddTask(task *Task) {
func (tm *Manager) AddTask(task *Task) {
task.ID = tm.curID
atomic.AddInt64(&tm.curID, 1)
tm.tasks.Store(task.ID, task)
}

func (tm *TaskManager) GetAll() []*Task {
func (tm *Manager) GetAll() []*Task {
return tm.tasks.Values()
}

func (tm *TaskManager) Get(id int64) (*Task, bool) {
func (tm *Manager) Get(id int64) (*Task, bool) {
return tm.tasks.Load(id)
}

func (tm *TaskManager) Remove(id int64) {
func (tm *Manager) Remove(id int64) {
tm.tasks.Delete(id)
}

func (tm *TaskManager) RemoveFinished() {
func (tm *Manager) RemoveFinished() {
tasks := tm.GetAll()
for _, task := range tasks {
if task.Status == FINISHED {
Expand All @@ -45,7 +45,7 @@ func (tm *TaskManager) RemoveFinished() {
}
}

func (tm *TaskManager) RemoveError() {
func (tm *Manager) RemoveError() {
tasks := tm.GetAll()
for _, task := range tasks {
if task.Error != nil {
Expand All @@ -54,8 +54,8 @@ func (tm *TaskManager) RemoveError() {
}
}

func (tm *TaskManager) Add(name string, f Func) {
task := NewTask(name, f)
func (tm *Manager) Add(name string, f Func) {
task := newTask(name, f)
tm.AddTask(task)
go task.Run()
}
64 changes: 64 additions & 0 deletions pkg/task/task.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Package task manage task, such as file upload, file copy between accounts, offline download, etc.
package task

import (
"context"
"github.com/pkg/errors"
)

var (
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
CANCELING = "canceling"
CANCELED = "canceled"
)

type Func func(task *Task) error

type Task struct {
ID int64
Name string
Status string
Error error
Func Func
Ctx context.Context
cancel context.CancelFunc
}

func newTask(name string, func_ Func) *Task {
ctx, cancel := context.WithCancel(context.Background())
return &Task{
Name: name,
Status: PENDING,
Func: func_,
Ctx: ctx,
cancel: cancel,
}
}

func (t *Task) SetStatus(status string) {
t.Status = status
}

func (t *Task) Run() {
t.Status = RUNNING
t.Error = t.Func(t)
if errors.Is(t.Ctx.Err(), context.Canceled) {
t.Status = CANCELED
} else {
t.Status = FINISHED
}
}

func (t *Task) Retry() {
t.Run()
}

func (t *Task) Cancel() {
if t.cancel != nil {
t.cancel()
}
// maybe can't cancel
t.Status = CANCELING
}
12 changes: 12 additions & 0 deletions pkg/utils/ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package utils

import "context"

func IsCanceled(ctx context.Context) bool {
select {
case <-ctx.Done():
return true
default:
return false
}
}

0 comments on commit fa6e918

Please sign in to comment.