Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/payload/task_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,5 +589,9 @@ func RunGraph(ctx context.Context, graph *TaskGraph, maxParallelism int, fn func
if len(errs) > 0 {
return errs
}
// if the context was cancelled, we may have unfinished work
if err := ctx.Err(); err != nil {
return []error{err}
}
return nil
}
31 changes: 28 additions & 3 deletions pkg/payload/task_graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,12 +609,14 @@ func TestRunGraph(t *testing.T) {
wantErrs []string
}{
{
name: "tasks executed in order",
nodes: []*TaskNode{
{Tasks: tasks("a", "b")},
},
order: []string{"a", "b"},
},
{
name: "nodes executed after dependencies",
nodes: []*TaskNode{
{Tasks: tasks("c"), In: []int{3}},
{Tasks: tasks("d", "e"), In: []int{3}},
Expand Down Expand Up @@ -644,6 +646,7 @@ func TestRunGraph(t *testing.T) {
},
},
{
name: "mid-task cancellation error interrupts node processing",
nodes: []*TaskNode{
{Tasks: tasks("c"), In: []int{2}},
{Tasks: tasks("d"), In: []int{2}, Out: []int{3}},
Expand All @@ -669,6 +672,7 @@ func TestRunGraph(t *testing.T) {
},
},
{
name: "task error interrupts node processing",
nodes: []*TaskNode{
{Tasks: tasks("c"), In: []int{2}},
{Tasks: tasks("d"), In: []int{2}, Out: []int{3}},
Expand All @@ -684,15 +688,15 @@ func TestRunGraph(t *testing.T) {
case <-time.After(time.Second):
t.Fatalf("expected context")
case <-ctx.Done():
t.Logf("got cancelled context")
return fmt.Errorf("cancelled")
t.Logf("got canceled context")
return ctx.Err()
}
return fmt.Errorf("error A")
}
return nil
},
want: []string{"a", "b", "c"},
wantErrs: []string{"cancelled"},
wantErrs: []string{"context canceled"},
invariants: func(t *testing.T, got []string) {
for _, s := range got {
if s == "e" {
Expand All @@ -702,6 +706,7 @@ func TestRunGraph(t *testing.T) {
},
},
{
name: "task errors in parallel nodes both reported",
nodes: []*TaskNode{
{Tasks: tasks("a"), Out: []int{1}},
{Tasks: tasks("b"), In: []int{0}, Out: []int{2, 4, 8}},
Expand All @@ -727,6 +732,26 @@ func TestRunGraph(t *testing.T) {
want: []string{"a", "b", "d1", "d2", "d3"},
wantErrs: []string{"error - c1", "error - f"},
},
{
name: "cancelation without task errors is reported",
nodes: []*TaskNode{
{Tasks: tasks("a"), Out: []int{1}},
{Tasks: tasks("b"), In: []int{0}},
},
sleep: time.Millisecond,
parallel: 1,
errorOn: func(t *testing.T, name string, ctx context.Context, cancelFn func()) error {
if name == "a" {
cancelFn()
time.Sleep(time.Second)
return nil
}
t.Fatalf("task b should never run")
return nil
},
want: []string{"a"},
wantErrs: []string{"context canceled"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down