Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions c_glib/arrow-glib/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1914,18 +1914,18 @@ garrow_execute_plan_start(GArrowExecutePlan *plan,
}

/**
* garrow_execute_plan_stop:
* garrow_execute_plan_abort:
* @plan: A #GArrowExecutePlan.
*
* Stops this plan.
*
* Since: 6.0.0
*/
void
garrow_execute_plan_stop(GArrowExecutePlan *plan)
garrow_execute_plan_abort(GArrowExecutePlan *plan)
{
auto arrow_plan = garrow_execute_plan_get_raw(plan);
arrow_plan->StopProducing();
arrow_plan->Abort();
}

/**
Expand Down
21 changes: 10 additions & 11 deletions cpp/examples/arrow/compute_register_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,29 @@ class ExampleNode : public cp::ExecNode {
ExampleNode(ExecNode* input, const ExampleNodeOptions&)
: ExecNode(/*plan=*/input->plan(), /*inputs=*/{input},
/*input_labels=*/{"ignored"},
/*output_schema=*/input->output_schema(), /*num_outputs=*/1) {}
/*output_schema=*/input->output_schema()) {}

const char* kind_name() const override { return "ExampleNode"; }

arrow::Status StartProducing() override {
outputs_[0]->InputFinished(this, 0);
return arrow::Status::OK();
}
arrow::Status StartProducing() override { return output_->InputFinished(this, 0); }

void ResumeProducing(ExecNode* output, int32_t counter) override {
inputs_[0]->ResumeProducing(this, counter);
}

void PauseProducing(ExecNode* output, int32_t counter) override {
inputs_[0]->PauseProducing(this, counter);
}

void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); }
void StopProducing() override { inputs_[0]->StopProducing(); }
arrow::Status InputReceived(ExecNode* input, cp::ExecBatch batch) override {
return arrow::Status::OK();
}

void InputReceived(ExecNode* input, cp::ExecBatch batch) override {}
void ErrorReceived(ExecNode* input, arrow::Status error) override {}
void InputFinished(ExecNode* input, int total_batches) override {}
arrow::Status InputFinished(ExecNode* input, int total_batches) override {
return arrow::Status::OK();
}

arrow::Future<> finished() override { return inputs_[0]->finished(); }
void Abort() override {}
};

arrow::Result<cp::ExecNode*> ExampleExecNodeFactory(cp::ExecPlan* plan,
Expand Down
1 change: 0 additions & 1 deletion cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ if(ARROW_COMPUTE)
compute/exec/hash_join_node.cc
compute/exec/key_hash.cc
compute/exec/key_map.cc
compute/exec/map_node.cc
compute/exec/order_by_impl.cc
compute/exec/partition_util.cc
compute/exec/options.cc
Expand Down
106 changes: 28 additions & 78 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ class ScalarAggregateNode : public ExecNode {
std::vector<const ScalarAggregateKernel*> kernels,
std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(plan, std::move(inputs), {"target"},
/*output_schema=*/std::move(output_schema),
/*num_outputs=*/1),
/*output_schema=*/std::move(output_schema)),
target_field_ids_(std::move(target_field_ids)),
aggs_(std::move(aggs)),
kernels_(std::move(kernels)),
Expand Down Expand Up @@ -159,7 +158,7 @@ class ScalarAggregateNode : public ExecNode {
return Status::OK();
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
Status InputReceived(ExecNode* input, ExecBatch batch) override {
EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
util::tracing::Span span;
START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived",
Expand All @@ -170,25 +169,21 @@ class ScalarAggregateNode : public ExecNode {

auto thread_index = plan_->GetThreadIndex();

if (ErrorIfNotOk(DoConsume(ExecSpan(batch), thread_index))) return;
RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index));

if (input_counter_.Increment()) {
ErrorIfNotOk(Finish());
return Finish();
}
return Status::OK();
}

void ErrorReceived(ExecNode* input, Status error) override {
EVENT(span_, "ErrorReceived", {{"error", error.message()}});
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
Status InputFinished(ExecNode* input, int total_batches) override {
EVENT(span_, "InputFinished", {{"batches.length", total_batches}});
DCHECK_EQ(input, inputs_[0]);
if (input_counter_.SetTotal(total_batches)) {
ErrorIfNotOk(Finish());
return Finish();
}
return Status::OK();
}

Status StartProducing() override {
Expand All @@ -197,8 +192,7 @@ class ScalarAggregateNode : public ExecNode {
{"node.detail", ToString()},
{"node.kind", kind_name()}});
// Scalar aggregates will only output a single batch
outputs_[0]->InputFinished(this, 1);
return Status::OK();
return output_->InputFinished(this, 1);
}

void PauseProducing(ExecNode* output, int32_t counter) override {
Expand All @@ -209,18 +203,7 @@ class ScalarAggregateNode : public ExecNode {
inputs_[0]->ResumeProducing(this, counter);
}

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
}

void StopProducing() override {
EVENT(span_, "StopProducing");
if (input_counter_.Cancel()) {
finished_.MarkFinished();
}
inputs_[0]->StopProducing(this);
}
void Abort() override {}

protected:
std::string ToStringExtra(int indent = 0) const override {
Expand Down Expand Up @@ -251,9 +234,7 @@ class ScalarAggregateNode : public ExecNode {
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
}

outputs_[0]->InputReceived(this, std::move(batch));
finished_.MarkFinished();
return Status::OK();
return output_->InputReceived(this, std::move(batch));
}

const std::vector<int> target_field_ids_;
Expand All @@ -271,24 +252,19 @@ class GroupByNode : public ExecNode {
std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
/*num_outputs=*/1),
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)),
ctx_(ctx),
key_field_ids_(std::move(key_field_ids)),
agg_src_field_ids_(std::move(agg_src_field_ids)),
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}

Status Init() override {
RETURN_NOT_OK(ExecNode::Init());
output_task_group_id_ = plan_->RegisterTaskGroup(
[this](size_t, int64_t task_id) {
OutputNthBatch(task_id);
return Status::OK();
},
[this](size_t) {
finished_.MarkFinished();
return Status::OK();
});
[this](size_t, int64_t task_id) { return OutputNthBatch(task_id); },
[](size_t) { return Status::OK(); });
local_states_.resize(plan_->max_concurrency());
return Status::OK();
}

Expand Down Expand Up @@ -475,12 +451,9 @@ class GroupByNode : public ExecNode {
return out_data;
}

void OutputNthBatch(int64_t n) {
// bail if StopProducing was called
if (finished_.is_finished()) return;

Status OutputNthBatch(int64_t n) {
int64_t batch_size = output_batch_size();
outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
return output_->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
}

Status OutputResult() {
Expand All @@ -496,50 +469,36 @@ class GroupByNode : public ExecNode {
ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());

int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_, num_output_batches));
return Status::OK();
RETURN_NOT_OK(output_->InputFinished(this, static_cast<int>(num_output_batches)));
return plan_->StartTaskGroup(output_task_group_id_, num_output_batches);
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
Status InputReceived(ExecNode* input, ExecBatch batch) override {
EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
util::tracing::Span span;
START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived",
{{"group_by", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});

// bail if StopProducing was called
if (finished_.is_finished()) return;

DCHECK_EQ(input, inputs_[0]);

if (ErrorIfNotOk(Consume(ExecSpan(batch)))) return;

RETURN_NOT_OK(Consume(ExecSpan(batch)));
if (input_counter_.Increment()) {
ErrorIfNotOk(OutputResult());
return OutputResult();
}
return Status::OK();
}

void ErrorReceived(ExecNode* input, Status error) override {
EVENT(span_, "ErrorReceived", {{"error", error.message()}});

DCHECK_EQ(input, inputs_[0]);

outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
Status InputFinished(ExecNode* input, int total_batches) override {
EVENT(span_, "InputFinished", {{"batches.length", total_batches}});

// bail if StopProducing was called
if (finished_.is_finished()) return;

DCHECK_EQ(input, inputs_[0]);

if (input_counter_.SetTotal(total_batches)) {
ErrorIfNotOk(OutputResult());
return OutputResult();
}
return Status::OK();
}

Status StartProducing() override {
Expand All @@ -548,7 +507,6 @@ class GroupByNode : public ExecNode {
{"node.detail", ToString()},
{"node.kind", kind_name()}});

local_states_.resize(plan_->max_concurrency());
return Status::OK();
}

Expand All @@ -562,15 +520,7 @@ class GroupByNode : public ExecNode {
// Without spillover there is no way to handle backpressure in this node
}

void StopProducing(ExecNode* output) override {
EVENT(span_, "StopProducing");
DCHECK_EQ(output, outputs_[0]);

if (input_counter_.Cancel()) finished_.MarkFinished();
inputs_[0]->StopProducing(this);
}

void StopProducing() override { StopProducing(outputs_[0]); }
void Abort() override { input_counter_.Cancel(); }

protected:
std::string ToStringExtra(int indent = 0) const override {
Expand Down
Loading