diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 0b70577ae78..725372700ca 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -486,7 +486,7 @@ class GroupByNode : public ExecNode { outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size)); } - Status OutputResult() { + Status DoOutputResult() { // To simplify merging, ensure that the first grouper is nonempty for (size_t i = 0; i < local_states_.size(); i++) { if (local_states_[i].grouper) { @@ -500,11 +500,28 @@ class GroupByNode : public ExecNode { int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size()); outputs_[0]->InputFinished(this, static_cast(num_output_batches)); - RETURN_NOT_OK(plan_->query_context()->StartTaskGroup(output_task_group_id_, - num_output_batches)); + Status st = + plan_->query_context()->StartTaskGroup(output_task_group_id_, num_output_batches); + if (st.IsCancelled()) { + // This means the user has cancelled/aborted the plan. We will not send any batches + // and end immediately. + finished_.MarkFinished(); + return Status::OK(); + } else { + return st; + } return Status::OK(); } + void OutputResult() { + // If something goes wrong outputting the result we need to make sure + // we still mark finished. + Status st = DoOutputResult(); + if (!st.ok()) { + finished_.MarkFinished(st); + } + } + void InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); util::tracing::Span span; @@ -521,7 +538,7 @@ class GroupByNode : public ExecNode { if (ErrorIfNotOk(Consume(ExecSpan(batch)))) return; if (input_counter_.Increment()) { - ErrorIfNotOk(OutputResult()); + OutputResult(); } } @@ -542,7 +559,7 @@ class GroupByNode : public ExecNode { DCHECK_EQ(input, inputs_[0]); if (input_counter_.SetTotal(total_batches)) { - ErrorIfNotOk(OutputResult()); + OutputResult(); } } @@ -551,7 +568,6 @@ class GroupByNode : public ExecNode { {{"node.label", label()}, {"node.detail", ToString()}, {"node.kind", kind_name()}}); - local_states_.resize(plan_->query_context()->max_concurrency()); return Status::OK(); } @@ -570,7 +586,9 @@ class GroupByNode : public ExecNode { EVENT(span_, "StopProducing"); DCHECK_EQ(output, outputs_[0]); - if (input_counter_.Cancel()) finished_.MarkFinished(); + if (input_counter_.Cancel()) { + finished_.MarkFinished(); + } inputs_[0]->StopProducing(this); }