diff --git a/pkg/payload/task_graph.go b/pkg/payload/task_graph.go index 4976182842..c0106a4eb7 100644 --- a/pkg/payload/task_graph.go +++ b/pkg/payload/task_graph.go @@ -612,5 +612,9 @@ func RunGraph(ctx context.Context, graph *TaskGraph, maxParallelism int, fn func if len(errs) > 0 { return errs } + // if the context was cancelled, we may have unfinished work + if err := ctx.Err(); err != nil { + return []error{err} + } return nil } diff --git a/pkg/payload/task_graph_test.go b/pkg/payload/task_graph_test.go index b1a6d503c4..3ae17549b5 100644 --- a/pkg/payload/task_graph_test.go +++ b/pkg/payload/task_graph_test.go @@ -721,12 +721,14 @@ func TestRunGraph(t *testing.T) { wantErrs []string }{ { + name: "tasks executed in order", nodes: []*TaskNode{ {Tasks: tasks("a", "b")}, }, order: []string{"a", "b"}, }, { + name: "nodes executed after dependencies", nodes: []*TaskNode{ {Tasks: tasks("c"), In: []int{3}}, {Tasks: tasks("d", "e"), In: []int{3}}, @@ -756,6 +758,7 @@ func TestRunGraph(t *testing.T) { }, }, { + name: "task error interrupts node processing", nodes: []*TaskNode{ {Tasks: tasks("c"), In: []int{2}}, {Tasks: tasks("d"), In: []int{2}, Out: []int{3}}, @@ -781,6 +784,7 @@ func TestRunGraph(t *testing.T) { }, }, { + name: "mid-task cancellation error interrupts node processing", nodes: []*TaskNode{ {Tasks: tasks("c"), In: []int{2}}, {Tasks: tasks("d"), In: []int{2}, Out: []int{3}}, @@ -796,15 +800,15 @@ func TestRunGraph(t *testing.T) { case <-time.After(time.Second): t.Fatalf("expected context") case <-ctx.Done(): - t.Logf("got cancelled context") - return fmt.Errorf("cancelled") + t.Logf("got canceled context") + return ctx.Err() } return fmt.Errorf("error A") } return nil }, want: []string{"a", "b", "c"}, - wantErrs: []string{"cancelled"}, + wantErrs: []string{"context canceled"}, invariants: func(t *testing.T, got []string) { for _, s := range got { if s == "e" { @@ -814,6 +818,7 @@ func TestRunGraph(t *testing.T) { }, }, { + name: "task errors in parallel nodes both reported", nodes: []*TaskNode{ {Tasks: tasks("a"), Out: []int{1}}, {Tasks: tasks("b"), In: []int{0}, Out: []int{2, 4, 8}}, @@ -839,6 +844,26 @@ func TestRunGraph(t *testing.T) { want: []string{"a", "b", "d1", "d2", "d3"}, wantErrs: []string{"error - c1", "error - f"}, }, + { + name: "cancelation without task errors is reported", + nodes: []*TaskNode{ + {Tasks: tasks("a"), Out: []int{1}}, + {Tasks: tasks("b"), In: []int{0}}, + }, + sleep: time.Millisecond, + parallel: 1, + errorOn: func(t *testing.T, name string, ctx context.Context, cancelFn func()) error { + if name == "a" { + cancelFn() + time.Sleep(time.Second) + return nil + } + t.Fatalf("task b should never run") + return nil + }, + want: []string{"a"}, + wantErrs: []string{"context canceled"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {