From d665cce7391815f62664ab211afd7b7e2984d0c3 Mon Sep 17 00:00:00 2001 From: Noah Hsu Date: Sat, 18 Jun 2022 20:38:14 +0800 Subject: [PATCH] feat: add task work limit --- internal/fs/copy.go | 6 ++--- internal/fs/put.go | 2 +- pkg/task/errors.go | 7 ++++++ pkg/task/manager.go | 53 +++++++++++++++++++++++++++++-------------- pkg/task/task.go | 17 +++++++++++--- pkg/task/task_test.go | 8 +++---- 6 files changed, 65 insertions(+), 28 deletions(-) create mode 100644 pkg/task/errors.go diff --git a/internal/fs/copy.go b/internal/fs/copy.go index 504eaf59201..b2bff07f426 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -32,7 +32,7 @@ func Copy(ctx context.Context, account driver.Driver, srcPath, dstPath string) ( return false, operations.Copy(ctx, account, srcActualPath, dstActualPath) } // not in an account - CopyTaskManager.Add( + CopyTaskManager.Submit( fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcActualPath, dstAccount.GetAccount().VirtualPath, dstActualPath), func(task *task.Task) error { return CopyBetween2Accounts(task, srcAccount, dstAccount, srcActualPath, dstActualPath) @@ -58,14 +58,14 @@ func CopyBetween2Accounts(t *task.Task, srcAccount, dstAccount driver.Driver, sr } srcObjPath := stdpath.Join(srcPath, obj.GetName()) dstObjPath := stdpath.Join(dstPath, obj.GetName()) - CopyTaskManager.Add( + CopyTaskManager.Submit( fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcObjPath, dstAccount.GetAccount().VirtualPath, dstObjPath), func(t *task.Task) error { return CopyBetween2Accounts(t, srcAccount, dstAccount, srcObjPath, dstObjPath) }) } } else { - CopyTaskManager.Add( + CopyTaskManager.Submit( fmt.Sprintf("copy [%s](%s) to [%s](%s)", srcAccount.GetAccount().VirtualPath, srcPath, dstAccount.GetAccount().VirtualPath, dstPath), func(t *task.Task) error { return CopyFileBetween2Accounts(t, srcAccount, dstAccount, srcPath, dstPath) diff --git a/internal/fs/put.go b/internal/fs/put.go index 6c2ce412029..0fde98fd009 100644 --- a/internal/fs/put.go +++ b/internal/fs/put.go @@ -18,7 +18,7 @@ func Put(ctx context.Context, account driver.Driver, parentPath string, file mod if err != nil { return errors.WithMessage(err, "failed get account") } - UploadTaskManager.Add(fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), account.GetAccount().VirtualPath, actualParentPath), func(task *task.Task) error { + UploadTaskManager.Submit(fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), account.GetAccount().VirtualPath, actualParentPath), func(task *task.Task) error { return operations.Put(task.Ctx, account, actualParentPath, file, nil) }) return nil diff --git a/pkg/task/errors.go b/pkg/task/errors.go new file mode 100644 index 00000000000..022a7cf9267 --- /dev/null +++ b/pkg/task/errors.go @@ -0,0 +1,7 @@ +package task + +import "errors" + +var ( + ErrTaskNotFound = errors.New("task not found") +) diff --git a/pkg/task/manager.go b/pkg/task/manager.go index de702737332..137b37984ef 100644 --- a/pkg/task/manager.go +++ b/pkg/task/manager.go @@ -1,25 +1,39 @@ package task import ( - "github.com/pkg/errors" + log "github.com/sirupsen/logrus" "sync/atomic" "github.com/alist-org/alist/v3/pkg/generic_sync" ) type Manager struct { - works uint - curID uint64 - tasks generic_sync.MapOf[uint64, *Task] + workerC chan struct{} + curID uint64 + tasks generic_sync.MapOf[uint64, *Task] } -func (tm *Manager) Add(name string, f Func) uint64 { +func (tm *Manager) Submit(name string, f Func) uint64 { task := newTask(name, f) tm.addTask(task) - go task.Run() + tm.do(task.ID) return task.ID } +func (tm *Manager) do(tid uint64) { + task := tm.MustGet(tid) + go func() { + log.Debugf("task [%s] waiting for worker", task.Name) + select { + case <-tm.workerC: + log.Debugf("task [%s] starting", task.Name) + task.run() + log.Debugf("task [%s] ended", task.Name) + } + tm.workerC <- struct{}{} + }() +} + func (tm *Manager) addTask(task *Task) { task.ID = tm.curID atomic.AddUint64(&tm.curID, 1) @@ -30,30 +44,35 @@ func (tm *Manager) GetAll() []*Task { return tm.tasks.Values() } -func (tm *Manager) Get(id uint64) (*Task, bool) { - return tm.tasks.Load(id) +func (tm *Manager) Get(tid uint64) (*Task, bool) { + return tm.tasks.Load(tid) +} + +func (tm *Manager) MustGet(tid uint64) *Task { + task, _ := tm.Get(tid) + return task } -func (tm *Manager) Retry(id uint64) error { - t, ok := tm.Get(id) +func (tm *Manager) Retry(tid uint64) error { + t, ok := tm.Get(tid) if !ok { - return errors.New("task not found") + return ErrTaskNotFound } - t.Retry() + tm.do(t.ID) return nil } -func (tm *Manager) Cancel(id uint64) error { - t, ok := tm.Get(id) +func (tm *Manager) Cancel(tid uint64) error { + t, ok := tm.Get(tid) if !ok { - return errors.New("task not found") + return ErrTaskNotFound } t.Cancel() return nil } -func (tm *Manager) Remove(id uint64) { - tm.tasks.Delete(id) +func (tm *Manager) Remove(tid uint64) { + tm.tasks.Delete(tid) } func (tm *Manager) RemoveFinished() { diff --git a/pkg/task/task.go b/pkg/task/task.go index 76d9f0091b0..cbf20c8ba5b 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -4,6 +4,7 @@ package task import ( "context" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" ) var ( @@ -12,6 +13,7 @@ var ( FINISHED = "finished" CANCELING = "canceling" CANCELED = "canceled" + ERRORED = "errored" ) type Func func(task *Task) error @@ -46,18 +48,27 @@ func (t *Task) SetProgress(percentage int) { t.Progress = percentage } -func (t *Task) Run() { +func (t *Task) run() { t.Status = RUNNING + defer func() { + if err := recover(); err != nil { + log.Errorf("error [%+v] while run task [%s]", err, t.Name) + t.Error = errors.Errorf("panic: %+v", err) + t.Status = ERRORED + } + }() t.Error = t.Func(t) if errors.Is(t.Ctx.Err(), context.Canceled) { t.Status = CANCELED + } else if t.Error != nil { + t.Status = ERRORED } else { t.Status = FINISHED } } -func (t *Task) Retry() { - t.Run() +func (t *Task) retry() { + t.run() } func (t *Task) Cancel() { diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 352781aa41c..84e1ca0b8e2 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -9,7 +9,7 @@ import ( func TestTask_Manager(t *testing.T) { tm := NewTaskManager() - id := tm.Add("test", func(task *Task) error { + id := tm.Submit("test", func(task *Task) error { time.Sleep(time.Millisecond * 500) return nil }) @@ -29,7 +29,7 @@ func TestTask_Manager(t *testing.T) { func TestTask_Cancel(t *testing.T) { tm := NewTaskManager() - id := tm.Add("test", func(task *Task) error { + id := tm.Submit("test", func(task *Task) error { for { if utils.IsCanceled(task.Ctx) { return nil @@ -53,7 +53,7 @@ func TestTask_Cancel(t *testing.T) { func TestTask_Retry(t *testing.T) { tm := NewTaskManager() num := 0 - id := tm.Add("test", func(task *Task) error { + id := tm.Submit("test", func(task *Task) error { num++ if num&1 == 1 { return errors.New("test error") @@ -71,7 +71,7 @@ func TestTask_Retry(t *testing.T) { } else { t.Logf("task error: %s", task.Error) } - task.Retry() + task.retry() time.Sleep(time.Millisecond) if task.Error != nil { t.Errorf("task error: %+v, but expected nil", task.Error)