Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 85 additions & 139 deletions pkg/payload/task_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"fmt"
"math/rand"
"regexp"
"sort"
"strings"
"sync"

"k8s.io/klog"

"k8s.io/apimachinery/pkg/runtime/schema"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/klog"
)

// SplitOnJobs enforces the rule that any Job in the payload prevents reordering or parallelism (either before or after)
Expand Down Expand Up @@ -426,180 +424,128 @@ type runTasks struct {
}

type taskStatus struct {
index int
success bool
index int
error error
}

// RunGraph executes the provided graph in order and in parallel up to maxParallelism. It will not start
// a new TaskNode until all of the prerequisites have completed. If fn returns an error, no dependencies
// of that node will be executed, but other indepedent edges will continue executing.
func RunGraph(ctx context.Context, graph *TaskGraph, maxParallelism int, fn func(ctx context.Context, tasks []*Task) error) []error {
nestedCtx, cancelFn := context.WithCancel(ctx)
defer cancelFn()

// This goroutine takes nodes from the graph as they are available (their prereq has completed) and
// sends them to workCh. It uses completeCh to know that a previously dispatched item is complete.
completeCh := make(chan taskStatus, maxParallelism)
defer close(completeCh)
submitted := make([]bool, len(graph.Nodes))
results := make([]*taskStatus, len(graph.Nodes))

workCh := make(chan runTasks, maxParallelism)
go func() {
defer close(workCh)

// visited tracks nodes we have not sent (0), are currently
// waiting for completion (1), or have completed (2,3)
const (
nodeNotVisited int = iota
nodeWorking
nodeFailed
nodeComplete
)
visited := make([]int, len(graph.Nodes))
canVisit := func(node *TaskNode) bool {
for _, previous := range node.In {
switch visited[previous] {
case nodeFailed, nodeWorking, nodeNotVisited:
return false
}
canVisit := func(node *TaskNode) bool {
for _, previous := range node.In {
if result := results[previous]; result == nil || result.error != nil {
return false
}
return true
}
return true
}

remaining := len(graph.Nodes)
var inflight int
for {
found := 0

// walk the graph, filling the work queue
for i := 0; i < len(visited); i++ {
if visited[i] != nodeNotVisited {
continue
}
if canVisit(graph.Nodes[i]) {
select {
case workCh <- runTasks{index: i, tasks: graph.Nodes[i].Tasks}:
visited[i] = nodeWorking
found++
inflight++
default:
break
}
}
}

// try to empty the done channel
for len(completeCh) > 0 {
finished := <-completeCh
if finished.success {
visited[finished.index] = nodeComplete
} else {
visited[finished.index] = nodeFailed
}
remaining--
inflight--
found++
}

if found > 0 {
getNextNode := func() int {
for i, node := range graph.Nodes {
if submitted[i] {
continue
}

// no more work to hand out
if remaining == 0 {
klog.V(4).Infof("Graph is complete")
return
if canVisit(node) {
return i
}
}

// we walked the entire graph, there are still nodes remaining, but we're not waiting
// for anything
if inflight == 0 && found == 0 {
klog.V(4).Infof("No more reachable nodes in graph, continue")
break
}
return -1
}

// we did nothing this round, so we have to wait for more
finished, ok := <-completeCh
if !ok {
// we've been aborted
klog.V(4).Infof("Stopped graph walker due to cancel")
return
}
if finished.success {
visited[finished.index] = nodeComplete
} else {
visited[finished.index] = nodeFailed
}
remaining--
inflight--
}
// Tasks go out to the workers via workCh, and results come brack
// from the workers via resultCh.
workCh := make(chan runTasks, maxParallelism)
defer close(workCh)

// take everything remaining and process in order
var unreachable []*Task
for i := 0; i < len(visited); i++ {
if visited[i] == nodeNotVisited && canVisit(graph.Nodes[i]) {
unreachable = append(unreachable, graph.Nodes[i].Tasks...)
}
}
if len(unreachable) > 0 {
sort.Slice(unreachable, func(i, j int) bool {
a, b := unreachable[i], unreachable[j]
return a.Index < b.Index
})
workCh <- runTasks{index: -1, tasks: unreachable}
klog.V(4).Infof("Waiting for last tasks")
<-completeCh
}
klog.V(4).Infof("No more work")
}()
resultCh := make(chan taskStatus, maxParallelism)
defer close(resultCh)

nestedCtx, cancelFn := context.WithCancel(ctx)
defer cancelFn()

errCh := make(chan error, maxParallelism)
wg := sync.WaitGroup{}
if maxParallelism < 1 {
maxParallelism = 1
}
for i := 0; i < maxParallelism; i++ {
wg.Add(1)
go func(job int) {
go func(ctx context.Context, job int) {
defer utilruntime.HandleCrash()
defer wg.Done()
for {
select {
case <-nestedCtx.Done():
klog.V(4).Infof("Canceled worker %d", job)
case <-ctx.Done():
klog.V(4).Infof("Canceled worker %d while waiting for work", job)
return
case runTask, ok := <-workCh:
if !ok {
klog.V(4).Infof("No more work for %d", job)
return
}
case runTask := <-workCh:
klog.V(4).Infof("Running %d on worker %d", runTask.index, job)
err := fn(nestedCtx, runTask.tasks)
completeCh <- taskStatus{index: runTask.index, success: err == nil}
if err != nil {
errCh <- err
}
err := fn(ctx, runTask.tasks)
resultCh <- taskStatus{index: runTask.index, error: err}
}
}
}(i)
}(nestedCtx, i)
}

var inflight int
nextNode := getNextNode()
done := false
for !done {
switch {
case ctx.Err() == nil && nextNode >= 0: // push a task or collect a result
select {
case workCh <- runTasks{index: nextNode, tasks: graph.Nodes[nextNode].Tasks}:
submitted[nextNode] = true
inflight++
case result := <-resultCh:
results[result.index] = &result
inflight--
case <-ctx.Done():
}
case inflight > 0: // no work available to push; collect results
result := <-resultCh
results[result.index] = &result
inflight--
default: // no work to push and nothing in flight. We're done
done = true
}
if !done {
nextNode = getNextNode()
}
}
go func() {
klog.V(4).Infof("Waiting for workers to complete")
wg.Wait()
klog.V(4).Infof("Workers finished")
close(errCh)
}()

cancelFn()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the runner is

  • waiting on ctx
  • waiting on work from workCh

the also passes along the ctx to the syn task fn.

maybe what we should do is,
the runner only loops on the workCh for work, and passes along the ctx to task sync fn so that we can terminate in progress work.
here we should just close workCh as there is no longer any work left to done.

that simplifies/compartmentalizes what's used for what.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the Go spec:

Sending to or closing a closed channel causes a run-time panic.

So if we manually call close(workCh) here, then this deferred call will panic, right? Maybe Go ignores panics from deferred functions, but still, I personally prefer having the worker watching both nestedCtx and workCh.

wg.Wait()
klog.V(4).Infof("Workers finished")

var errs []error
for err := range errCh {
errs = append(errs, err)
var firstIncompleteNode *TaskNode
incompleteCount := 0
for i, result := range results {
if result == nil {
if firstIncompleteNode == nil {
firstIncompleteNode = graph.Nodes[i]
}
incompleteCount++
} else if result.error != nil {
errs = append(errs, result.error)
}
}

if len(errs) == 0 && firstIncompleteNode != nil {
errs = append(errs, fmt.Errorf("%d incomplete task nodes, beginning with %s", incompleteCount, firstIncompleteNode.Tasks[0]))
if err := ctx.Err(); err != nil {
errs = append(errs, err)
}
}

klog.V(4).Infof("Result of work: %v", errs)
if len(errs) > 0 {
return errs
}
// if the context was cancelled, we may have unfinished work
if err := ctx.Err(); err != nil {
return []error{err}
}
return nil
}
18 changes: 16 additions & 2 deletions pkg/payload/task_graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,21 @@ func TestRunGraph(t *testing.T) {
tasks := func(names ...string) []*Task {
var arr []*Task
for _, name := range names {
arr = append(arr, &Task{Manifest: &lib.Manifest{OriginalFilename: name}})
manifest := &lib.Manifest{OriginalFilename: name}
err := manifest.UnmarshalJSON([]byte(fmt.Sprintf(`
{
"apiVersion": "v1",
"kind": "ConfigMap",
"metadata": {
"name": "%s",
"namespace": "default"
}
}
`, name)))
if err != nil {
t.Fatalf("load %s: %v", name, err)
}
arr = append(arr, &Task{Manifest: manifest})
}
return arr
}
Expand Down Expand Up @@ -862,7 +876,7 @@ func TestRunGraph(t *testing.T) {
return nil
},
want: []string{"a"},
wantErrs: []string{"context canceled"},
wantErrs: []string{`1 incomplete task nodes, beginning with configmap "default/b" (0 of 0)`, "context canceled"},
},
}
for _, tt := range tests {
Expand Down