diff --git a/pkg/payload/task_graph.go b/pkg/payload/task_graph.go
index 8b685894b9..4637c7bbc9 100644
--- a/pkg/payload/task_graph.go
+++ b/pkg/payload/task_graph.go
@@ -5,14 +5,12 @@ import (
 	"fmt"
 	"math/rand"
 	"regexp"
-	"sort"
 	"strings"
 	"sync"
 
-	"k8s.io/klog"
-
 	"k8s.io/apimachinery/pkg/runtime/schema"
 	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+	"k8s.io/klog"
 )
 
 // SplitOnJobs enforces the rule that any Job in the payload prevents reordering or parallelism (either before or after)
@@ -426,180 +424,128 @@ type runTasks struct {
 }
 
 type taskStatus struct {
-	index   int
-	success bool
+	index int
+	error error
 }
 
 // RunGraph executes the provided graph in order and in parallel up to maxParallelism. It will not start
 // a new TaskNode until all of the prerequisites have completed. If fn returns an error, no dependents
 // of that node will be executed, but other independent edges will continue executing.
 func RunGraph(ctx context.Context, graph *TaskGraph, maxParallelism int, fn func(ctx context.Context, tasks []*Task) error) []error {
-	nestedCtx, cancelFn := context.WithCancel(ctx)
-	defer cancelFn()
-
-	// This goroutine takes nodes from the graph as they are available (their prereq has completed) and
-	// sends them to workCh. It uses completeCh to know that a previously dispatched item is complete.
-	completeCh := make(chan taskStatus, maxParallelism)
-	defer close(completeCh)
+	submitted := make([]bool, len(graph.Nodes))
+	results := make([]*taskStatus, len(graph.Nodes))
 
-	workCh := make(chan runTasks, maxParallelism)
-	go func() {
-		defer close(workCh)
-
-		// visited tracks nodes we have not sent (0), are currently
-		// waiting for completion (1), or have completed (2,3)
-		const (
-			nodeNotVisited int = iota
-			nodeWorking
-			nodeFailed
-			nodeComplete
-		)
-		visited := make([]int, len(graph.Nodes))
-		canVisit := func(node *TaskNode) bool {
-			for _, previous := range node.In {
-				switch visited[previous] {
-				case nodeFailed, nodeWorking, nodeNotVisited:
-					return false
-				}
+	canVisit := func(node *TaskNode) bool {
+		for _, previous := range node.In {
+			if result := results[previous]; result == nil || result.error != nil {
+				return false
 			}
-			return true
 		}
+		return true
+	}
 
-		remaining := len(graph.Nodes)
-		var inflight int
-		for {
-			found := 0
-
-			// walk the graph, filling the work queue
-			for i := 0; i < len(visited); i++ {
-				if visited[i] != nodeNotVisited {
-					continue
-				}
-				if canVisit(graph.Nodes[i]) {
-					select {
-					case workCh <- runTasks{index: i, tasks: graph.Nodes[i].Tasks}:
-						visited[i] = nodeWorking
-						found++
-						inflight++
-					default:
-						break
-					}
-				}
-			}
-
-			// try to empty the done channel
-			for len(completeCh) > 0 {
-				finished := <-completeCh
-				if finished.success {
-					visited[finished.index] = nodeComplete
-				} else {
-					visited[finished.index] = nodeFailed
-				}
-				remaining--
-				inflight--
-				found++
-			}
-
-			if found > 0 {
+	getNextNode := func() int {
+		for i, node := range graph.Nodes {
+			if submitted[i] {
 				continue
 			}
-
-			// no more work to hand out
-			if remaining == 0 {
-				klog.V(4).Infof("Graph is complete")
-				return
+			if canVisit(node) {
+				return i
 			}
+		}
 
-			// we walked the entire graph, there are still nodes remaining, but we're not waiting
-			// for anything
-			if inflight == 0 && found == 0 {
-				klog.V(4).Infof("No more reachable nodes in graph, continue")
-				break
-			}
+		return -1
+	}
 
-			// we did nothing this round, so we have to wait for more
-			finished, ok := <-completeCh
-			if !ok {
-				// we've been aborted
-				klog.V(4).Infof("Stopped graph walker due to cancel")
-				return
-			}
-			if finished.success {
-				visited[finished.index] = nodeComplete
-			} else {
-				visited[finished.index] = nodeFailed
-			}
-			remaining--
-			inflight--
-		}
+	// Tasks go out to the workers via workCh, and results come back
+	// from the workers via resultCh.
+	workCh := make(chan runTasks, maxParallelism)
+	defer close(workCh)
 
-		// take everything remaining and process in order
-		var unreachable []*Task
-		for i := 0; i < len(visited); i++ {
-			if visited[i] == nodeNotVisited && canVisit(graph.Nodes[i]) {
-				unreachable = append(unreachable, graph.Nodes[i].Tasks...)
-			}
-		}
-		if len(unreachable) > 0 {
-			sort.Slice(unreachable, func(i, j int) bool {
-				a, b := unreachable[i], unreachable[j]
-				return a.Index < b.Index
-			})
-			workCh <- runTasks{index: -1, tasks: unreachable}
-			klog.V(4).Infof("Waiting for last tasks")
-			<-completeCh
-		}
-		klog.V(4).Infof("No more work")
-	}()
+	resultCh := make(chan taskStatus, maxParallelism)
+	defer close(resultCh)
+
+	nestedCtx, cancelFn := context.WithCancel(ctx)
+	defer cancelFn()
 
-	errCh := make(chan error, maxParallelism)
 	wg := sync.WaitGroup{}
 	if maxParallelism < 1 {
 		maxParallelism = 1
 	}
 	for i := 0; i < maxParallelism; i++ {
 		wg.Add(1)
-		go func(job int) {
+		go func(ctx context.Context, job int) {
 			defer utilruntime.HandleCrash()
 			defer wg.Done()
 			for {
 				select {
-				case <-nestedCtx.Done():
-					klog.V(4).Infof("Canceled worker %d", job)
+				case <-ctx.Done():
+					klog.V(4).Infof("Canceled worker %d while waiting for work", job)
 					return
-				case runTask, ok := <-workCh:
-					if !ok {
-						klog.V(4).Infof("No more work for %d", job)
-						return
-					}
+				case runTask := <-workCh:
 					klog.V(4).Infof("Running %d on worker %d", runTask.index, job)
-					err := fn(nestedCtx, runTask.tasks)
-					completeCh <- taskStatus{index: runTask.index, success: err == nil}
-					if err != nil {
-						errCh <- err
-					}
+					err := fn(ctx, runTask.tasks)
+					resultCh <- taskStatus{index: runTask.index, error: err}
 				}
 			}
-		}(i)
+		}(nestedCtx, i)
+	}
+
+	var inflight int
+	nextNode := getNextNode()
+	done := false
+	for !done {
+		switch {
+		case ctx.Err() == nil && nextNode >= 0: // push a task or collect a result
+			select {
+			case workCh <- runTasks{index: nextNode, tasks: graph.Nodes[nextNode].Tasks}:
+				submitted[nextNode] = true
+				inflight++
+			case result := <-resultCh:
+				results[result.index] = &result
+				inflight--
+			case <-ctx.Done():
+			}
+		case inflight > 0: // no work available to push; collect results
			result := <-resultCh
+			results[result.index] = &result
+			inflight--
+		default: // no work to push and nothing in flight. We're done
+			done = true
+		}
+		if !done {
+			nextNode = getNextNode()
+		}
 	}
-	go func() {
-		klog.V(4).Infof("Waiting for workers to complete")
-		wg.Wait()
-		klog.V(4).Infof("Workers finished")
-		close(errCh)
-	}()
+
+	cancelFn()
+	wg.Wait()
+	klog.V(4).Infof("Workers finished")
 
 	var errs []error
-	for err := range errCh {
-		errs = append(errs, err)
+	var firstIncompleteNode *TaskNode
+	incompleteCount := 0
+	for i, result := range results {
+		if result == nil {
+			if firstIncompleteNode == nil {
+				firstIncompleteNode = graph.Nodes[i]
+			}
+			incompleteCount++
+		} else if result.error != nil {
+			errs = append(errs, result.error)
+		}
+	}
+
+	if len(errs) == 0 && firstIncompleteNode != nil {
+		errs = append(errs, fmt.Errorf("%d incomplete task nodes, beginning with %s", incompleteCount, firstIncompleteNode.Tasks[0]))
+		if err := ctx.Err(); err != nil {
+			errs = append(errs, err)
+		}
 	}
+
+	klog.V(4).Infof("Result of work: %v", errs)
 	if len(errs) > 0 {
 		return errs
 	}
-	// if the context was cancelled, we may have unfinished work
-	if err := ctx.Err(); err != nil {
-		return []error{err}
-	}
 	return nil
 }
diff --git a/pkg/payload/task_graph_test.go b/pkg/payload/task_graph_test.go
index e0230b4ec7..c3aee61915 100644
--- a/pkg/payload/task_graph_test.go
+++ b/pkg/payload/task_graph_test.go
@@ -704,7 +704,21 @@ func TestRunGraph(t *testing.T) {
 	tasks := func(names ...string) []*Task {
 		var arr []*Task
 		for _, name := range names {
-			arr = append(arr, &Task{Manifest: &lib.Manifest{OriginalFilename: name}})
+			manifest := &lib.Manifest{OriginalFilename: name}
+			err := manifest.UnmarshalJSON([]byte(fmt.Sprintf(`
+{
+  "apiVersion": "v1",
+  "kind": "ConfigMap",
+  "metadata": {
+    "name": "%s",
+    "namespace": "default"
+  }
+}
+`, name)))
+			if err != nil {
+				t.Fatalf("load %s: %v", name, err)
+			}
+			arr = append(arr, &Task{Manifest: manifest})
 		}
 		return arr
 	}
@@ -862,7 +876,7 @@ func TestRunGraph(t *testing.T) {
 			return nil
 		},
 		want:     []string{"a"},
-		wantErrs: []string{"context canceled"},
+		wantErrs: []string{`1 incomplete task nodes, beginning with configmap "default/b" (0 of 0)`, "context canceled"},
 	},
 	}
 	for _, tt := range tests {
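
Worth spelling out for reviewers: the new `results` slice does double duty, recording both completion and failure, which is what lets `canVisit` replace the old four-state `visited` array. Below is a minimal, self-contained sketch of that gating logic; the names (`node`, `status`, the three-node chain) are illustrative assumptions of this sketch, not identifiers from pkg/payload. A node becomes visitable only once every prerequisite has a recorded result with a nil error, so nodes downstream of a failure are never submitted and later surface as nil entries in `results`.

```go
package main

import (
	"errors"
	"fmt"
)

// node mirrors the shape of TaskNode for this sketch: in lists the
// indices of prerequisite nodes, like TaskNode.In.
type node struct {
	name string
	in   []int
}

// status stands in for taskStatus: a nil *status means "never ran".
type status struct{ err error }

func main() {
	nodes := []node{
		{name: "a"},               // no prerequisites
		{name: "b", in: []int{0}}, // needs a
		{name: "c", in: []int{1}}, // needs b
	}
	results := make([]*status, len(nodes))

	// Same gating as the new canVisit: a prerequisite blocks its
	// dependents until it has a result, and forever if that result
	// carried an error.
	canVisit := func(n node) bool {
		for _, previous := range n.in {
			if r := results[previous]; r == nil || r.err != nil {
				return false
			}
		}
		return true
	}

	fmt.Printf("b visitable: %v\n", canVisit(nodes[1])) // false: a has no result yet
	results[0] = &status{}                              // a succeeded
	fmt.Printf("b visitable: %v\n", canVisit(nodes[1])) // true
	results[1] = &status{err: errors.New("boom")}       // b failed
	fmt.Printf("c visitable: %v\n", canVisit(nodes[2])) // false, and stays false
}
```

Those never-filled `results` entries are what feed the new `%d incomplete task nodes` error and the updated expectation in task_graph_test.go.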
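The other structural change is that pushing work and collecting results now happen in the main goroutine through a single `select`, instead of a separate walker goroutine plus an error channel. Here is a runnable sketch of that push-or-collect pattern, again under hypothetical names (`pending`, `runOne`) rather than the real package's types:

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// runOne stands in for the caller-provided fn; the real RunGraph passes
// a node's []*Task instead of an int.
func runOne(n int) error {
	if n == 4 {
		return fmt.Errorf("task %d failed", n)
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	const maxParallelism = 3
	workCh := make(chan int, maxParallelism)
	resultCh := make(chan error, maxParallelism)

	var wg sync.WaitGroup
	for i := 0; i < maxParallelism; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done(): // exit on cancellation, like the new workers
					return
				case n := <-workCh:
					resultCh <- runOne(n)
				}
			}
		}()
	}

	pending := []int{1, 2, 3, 4, 5} // stand-in for ready task nodes
	inflight := 0
	for len(pending) > 0 || inflight > 0 {
		if len(pending) > 0 {
			// Mirror the new main loop: either push work or collect a
			// result, whichever is possible first.
			select {
			case workCh <- pending[0]:
				pending = pending[1:]
				inflight++
			case err := <-resultCh:
				inflight--
				fmt.Println("result:", err)
			}
		} else {
			// Nothing left to push; drain outstanding results.
			err := <-resultCh
			inflight--
			fmt.Println("result:", err)
		}
	}

	cancel() // stop the idle workers, then wait for them
	wg.Wait()
}
```

Because `resultCh` is buffered to `maxParallelism`, a worker can always deliver its last result without blocking after cancellation, which is what makes the bare `cancelFn(); wg.Wait()` shutdown in the new code safe.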