diff --git a/pkg/payload/task_graph.go b/pkg/payload/task_graph.go index 4637c7bbc9..079b243a68 100644 --- a/pkg/payload/task_graph.go +++ b/pkg/payload/task_graph.go @@ -507,9 +507,18 @@ func RunGraph(ctx context.Context, graph *TaskGraph, maxParallelism int, fn func case <-ctx.Done(): } case inflight > 0: // no work available to push; collect results - result := <-resultCh - results[result.index] = &result - inflight-- + select { + case result := <-resultCh: + results[result.index] = &result + inflight-- + case <-ctx.Done(): + select { + case runTask := <-workCh: // workers canceled, so remove any work from the queue ourselves + inflight-- + submitted[runTask.index] = false + default: + } + } default: // no work to push and nothing in flight. We're done done = true } diff --git a/pkg/payload/task_graph_test.go b/pkg/payload/task_graph_test.go index c3aee61915..6b484c4092 100644 --- a/pkg/payload/task_graph_test.go +++ b/pkg/payload/task_graph_test.go @@ -831,6 +831,26 @@ func TestRunGraph(t *testing.T) { } }, }, + { + name: "mid-task cancellation with work in queue does not deadlock", + nodes: []*TaskNode{ + {Tasks: tasks("a1", "a2", "a3")}, + {Tasks: tasks("b")}, + }, + sleep: time.Millisecond, + parallel: 1, + errorOn: func(t *testing.T, name string, ctx context.Context, cancelFn func()) error { + if err := ctx.Err(); err != nil { + return err + } + if name == "a2" { + cancelFn() + } + return nil + }, + want: []string{"a1", "a2"}, + wantErrs: []string{"context canceled"}, + }, { name: "task errors in parallel nodes both reported", nodes: []*TaskNode{