Skip to content
12 changes: 3 additions & 9 deletions c_glib/arrow-glib/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,22 +1985,16 @@ garrow_execute_plan_validate(GArrowExecutePlan *plan,
/**
* garrow_execute_plan_start:
* @plan: A #GArrowExecutePlan.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Starts this plan.
*
* Returns: %TRUE on success, %FALSE on error.
*
* Since: 6.0.0
*/
gboolean
garrow_execute_plan_start(GArrowExecutePlan *plan,
GError **error)
void
garrow_execute_plan_start(GArrowExecutePlan *plan)
{
auto arrow_plan = garrow_execute_plan_get_raw(plan);
return garrow::check(error,
arrow_plan->StartProducing(),
"[execute-plan][start]");
arrow_plan->StartProducing();
}

/**
Expand Down
5 changes: 2 additions & 3 deletions c_glib/arrow-glib/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,8 @@ gboolean
garrow_execute_plan_validate(GArrowExecutePlan *plan,
GError **error);
GARROW_AVAILABLE_IN_6_0
gboolean
garrow_execute_plan_start(GArrowExecutePlan *plan,
GError **error);
void
garrow_execute_plan_start(GArrowExecutePlan *plan);
GARROW_AVAILABLE_IN_6_0
void
garrow_execute_plan_stop(GArrowExecutePlan *plan);
Expand Down
21 changes: 9 additions & 12 deletions cpp/examples/arrow/compute_register_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,11 @@ class ExampleNode : public cp::ExecNode {
ExampleNode(ExecNode* input, const ExampleNodeOptions&)
: ExecNode(/*plan=*/input->plan(), /*inputs=*/{input},
/*input_labels=*/{"ignored"},
/*output_schema=*/input->output_schema(), /*num_outputs=*/1) {}
/*output_schema=*/input->output_schema()) {}

const char* kind_name() const override { return "ExampleNode"; }

arrow::Status StartProducing() override {
outputs_[0]->InputFinished(this, 0);
return arrow::Status::OK();
}
arrow::Status StartProducing() override { return output_->InputFinished(this, 0); }

void ResumeProducing(ExecNode* output, int32_t counter) override {
inputs_[0]->ResumeProducing(this, counter);
Expand All @@ -90,14 +87,14 @@ class ExampleNode : public cp::ExecNode {
inputs_[0]->PauseProducing(this, counter);
}

void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); }
void StopProducing() override { inputs_[0]->StopProducing(); }
arrow::Status StopProducingImpl() override { return arrow::Status::OK(); }

void InputReceived(ExecNode* input, cp::ExecBatch batch) override {}
void ErrorReceived(ExecNode* input, arrow::Status error) override {}
void InputFinished(ExecNode* input, int total_batches) override {}

arrow::Future<> finished() override { return inputs_[0]->finished(); }
arrow::Status InputReceived(ExecNode* input, cp::ExecBatch batch) override {
return arrow::Status::OK();
}
arrow::Status InputFinished(ExecNode* input, int total_batches) override {
return arrow::Status::OK();
}
};

arrow::Result<cp::ExecNode*> ExampleExecNodeFactory(cp::ExecPlan* plan,
Expand Down
2 changes: 1 addition & 1 deletion cpp/examples/arrow/engine_substrait_consumption.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ arrow::Status RunSubstraitConsumer(int argc, char** argv) {

// Start the plan...
std::cout << std::string(50, '#') << " consuming batches:" << std::endl;
ARROW_RETURN_NOT_OK(plan->StartProducing());
plan->StartProducing();

// ... and wait for it to finish
ARROW_RETURN_NOT_OK(plan->finished().status());
Expand Down
4 changes: 2 additions & 2 deletions cpp/examples/arrow/execution_plan_documentation_examples.cc
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ arrow::Status ExecutePlanAndCollectAsTableWithCustomSink(
ARROW_RETURN_NOT_OK(plan->Validate());
std::cout << "ExecPlan created : " << plan->ToString() << std::endl;
// start the ExecPlan
ARROW_RETURN_NOT_OK(plan->StartProducing());
plan->StartProducing();

// collect sink_reader into a Table
std::shared_ptr<arrow::Table> response_table;
Expand Down Expand Up @@ -751,7 +751,7 @@ arrow::Status TableSinkExample() {
ARROW_RETURN_NOT_OK(plan->Validate());
std::cout << "ExecPlan created : " << plan->ToString() << std::endl;
// start the ExecPlan
ARROW_RETURN_NOT_OK(plan->StartProducing());
plan->StartProducing();

// Wait for the plan to finish
auto finished = plan->finished();
Expand Down
123 changes: 30 additions & 93 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
std::vector<const ScalarAggregateKernel*> kernels,
std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(plan, std::move(inputs), {"target"},
/*output_schema=*/std::move(output_schema),
/*num_outputs=*/1),
/*output_schema=*/std::move(output_schema)),
target_field_ids_(std::move(target_field_ids)),
aggs_(std::move(aggs)),
kernels_(std::move(kernels)),
Expand Down Expand Up @@ -155,36 +154,33 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
return Status::OK();
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
Status InputReceived(ExecNode* input, ExecBatch batch) override {
auto scope = TraceInputReceived(batch);
DCHECK_EQ(input, inputs_[0]);

auto thread_index = plan_->query_context()->GetThreadIndex();

if (ErrorIfNotOk(DoConsume(ExecSpan(batch), thread_index))) return;
ARROW_RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index));

if (input_counter_.Increment()) {
ErrorIfNotOk(Finish());
return Finish();
}
return Status::OK();
}

void ErrorReceived(ExecNode* input, Status error) override {
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
Status InputFinished(ExecNode* input, int total_batches) override {
EVENT_ON_CURRENT_SPAN("InputFinished", {{"batches.length", total_batches}});
DCHECK_EQ(input, inputs_[0]);
if (input_counter_.SetTotal(total_batches)) {
ErrorIfNotOk(Finish());
return Finish();
}
return Status::OK();
}

Status StartProducing() override {
NoteStartProducing(ToStringExtra());
// Scalar aggregates will only output a single batch
outputs_[0]->InputFinished(this, 1);
return Status::OK();
return output_->InputFinished(this, 1);
}

void PauseProducing(ExecNode* output, int32_t counter) override {
Expand All @@ -195,17 +191,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
inputs_[0]->ResumeProducing(this, counter);
}

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
}

void StopProducing() override {
if (input_counter_.Cancel()) {
finished_.MarkFinished();
}
inputs_[0]->StopProducing(this);
}
Status StopProducingImpl() override { return Status::OK(); }

protected:
std::string ToStringExtra(int indent = 0) const override {
Expand Down Expand Up @@ -234,9 +220,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode<ScalarAggregateNo
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
}

outputs_[0]->InputReceived(this, std::move(batch));
finished_.MarkFinished();
return Status::OK();
return output_->InputReceived(this, std::move(batch));
}

const std::vector<int> target_field_ids_;
Expand All @@ -254,23 +238,16 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
/*num_outputs=*/1),
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)),
key_field_ids_(std::move(key_field_ids)),
agg_src_field_ids_(std::move(agg_src_field_ids)),
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}

Status Init() override {
output_task_group_id_ = plan_->query_context()->RegisterTaskGroup(
[this](size_t, int64_t task_id) {
OutputNthBatch(task_id);
return Status::OK();
},
[this](size_t) {
finished_.MarkFinished();
return Status::OK();
});
[this](size_t, int64_t task_id) { return OutputNthBatch(task_id); },
[](size_t) { return Status::OK(); });
return Status::OK();
}

Expand Down Expand Up @@ -455,15 +432,13 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
return out_data;
}

void OutputNthBatch(int64_t n) {
// bail if StopProducing was called
if (finished_.is_finished()) return;

Status OutputNthBatch(int64_t n) {
int64_t batch_size = output_batch_size();
outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
return output_->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
}

Status DoOutputResult() {
Status OutputResult() {
auto scope = TraceFinish();
// To simplify merging, ensure that the first grouper is nonempty
for (size_t i = 0; i < local_states_.size(); i++) {
if (local_states_[i].grouper) {
Expand All @@ -476,60 +451,31 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());

int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
Status st =
plan_->query_context()->StartTaskGroup(output_task_group_id_, num_output_batches);
if (st.IsCancelled()) {
// This means the user has cancelled/aborted the plan. We will not send any batches
// and end immediately.
finished_.MarkFinished();
return Status::OK();
} else {
return st;
}
return Status::OK();
RETURN_NOT_OK(output_->InputFinished(this, static_cast<int>(num_output_batches)));
return plan_->query_context()->StartTaskGroup(output_task_group_id_,
num_output_batches);
}

void OutputResult() {
auto scope = TraceFinish();
// If something goes wrong outputting the result we need to make sure
// we still mark finished.
Status st = DoOutputResult();
if (!st.ok()) {
finished_.MarkFinished(st);
}
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
Status InputReceived(ExecNode* input, ExecBatch batch) override {
auto scope = TraceInputReceived(batch);

// bail if StopProducing was called
if (finished_.is_finished()) return;

DCHECK_EQ(input, inputs_[0]);

if (ErrorIfNotOk(Consume(ExecSpan(batch)))) return;
ARROW_RETURN_NOT_OK(Consume(ExecSpan(batch)));

if (input_counter_.Increment()) {
OutputResult();
return OutputResult();
}
return Status::OK();
}

void ErrorReceived(ExecNode* input, Status error) override {
DCHECK_EQ(input, inputs_[0]);

outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
// bail if StopProducing was called
if (finished_.is_finished()) return;

Status InputFinished(ExecNode* input, int total_batches) override {
DCHECK_EQ(input, inputs_[0]);

if (input_counter_.SetTotal(total_batches)) {
OutputResult();
return OutputResult();
}
return Status::OK();
}

Status StartProducing() override {
Expand All @@ -548,16 +494,7 @@ class GroupByNode : public ExecNode, public TracedNode<GroupByNode> {
// Without spillover there is way to handle backpressure in this node
}

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);

if (input_counter_.Cancel()) {
finished_.MarkFinished();
}
inputs_[0]->StopProducing(this);
}

void StopProducing() override { StopProducing(outputs_[0]); }
Status StopProducingImpl() override { return Status::OK(); }

protected:
std::string ToStringExtra(int indent = 0) const override {
Expand Down
Loading