From 1c75db4507a250018f0ec24428d21e6eded0853e Mon Sep 17 00:00:00 2001 From: Sasha Krassovsky Date: Wed, 3 Aug 2022 18:55:42 -0700 Subject: [PATCH] Part 2 of refactor --- c_glib/arrow-glib/compute.cpp | 6 +- .../arrow/compute_register_example.cc | 21 +- cpp/src/arrow/CMakeLists.txt | 1 - cpp/src/arrow/compute/exec/aggregate_node.cc | 106 ++----- cpp/src/arrow/compute/exec/asof_join_node.cc | 72 ++--- cpp/src/arrow/compute/exec/benchmark_util.cc | 9 +- cpp/src/arrow/compute/exec/exec_plan.cc | 262 ++++++------------ cpp/src/arrow/compute/exec/exec_plan.h | 51 +--- cpp/src/arrow/compute/exec/filter_node.cc | 57 ++-- cpp/src/arrow/compute/exec/hash_join.cc | 22 +- cpp/src/arrow/compute/exec/hash_join.h | 4 +- .../arrow/compute/exec/hash_join_benchmark.cc | 5 +- cpp/src/arrow/compute/exec/hash_join_node.cc | 84 ++---- .../arrow/compute/exec/hash_join_node_test.cc | 2 +- cpp/src/arrow/compute/exec/map_node.cc | 122 -------- cpp/src/arrow/compute/exec/map_node.h | 75 ----- cpp/src/arrow/compute/exec/plan_test.cc | 206 +++++++------- cpp/src/arrow/compute/exec/project_node.cc | 49 ++-- cpp/src/arrow/compute/exec/sink_node.cc | 170 ++++-------- cpp/src/arrow/compute/exec/source_node.cc | 82 ++++-- cpp/src/arrow/compute/exec/swiss_join.cc | 22 +- cpp/src/arrow/compute/exec/swiss_join.h | 6 +- cpp/src/arrow/compute/exec/test_util.cc | 52 ++-- cpp/src/arrow/compute/exec/test_util.h | 4 +- cpp/src/arrow/compute/exec/tpch_node.cc | 89 ++---- cpp/src/arrow/compute/exec/union_node.cc | 64 +---- .../compute/kernels/hash_aggregate_test.cc | 2 +- cpp/src/arrow/dataset/file_base.cc | 68 +++-- cpp/src/arrow/dataset/scan_node.cc | 37 +-- cpp/src/arrow/dataset/scanner.cc | 12 +- cpp/src/arrow/dataset/scanner_test.cc | 12 +- .../arrow/engine/substrait/function_test.cc | 4 +- cpp/src/arrow/engine/substrait/serde_test.cc | 3 +- cpp/src/arrow/util/async_util.cc | 4 +- cpp/src/arrow/util/async_util.h | 3 - 35 files changed, 628 insertions(+), 1160 deletions(-) delete mode 100644 cpp/src/arrow/compute/exec/map_node.cc delete mode 100644 cpp/src/arrow/compute/exec/map_node.h diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 3554fdf1158..0506c10847f 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -1914,7 +1914,7 @@ garrow_execute_plan_start(GArrowExecutePlan *plan, } /** - * garrow_execute_plan_stop: + * garrow_execute_plan_abort: * @plan: A #GArrowExecutePlan. * * Stops this plan. @@ -1922,10 +1922,10 @@ garrow_execute_plan_start(GArrowExecutePlan *plan, * Since: 6.0.0 */ void -garrow_execute_plan_stop(GArrowExecutePlan *plan) +garrow_execute_plan_abort(GArrowExecutePlan *plan) { auto arrow_plan = garrow_execute_plan_get_raw(plan); - arrow_plan->StopProducing(); + arrow_plan->Abort(); } /** diff --git a/cpp/examples/arrow/compute_register_example.cc b/cpp/examples/arrow/compute_register_example.cc index 1b96dd42220..9af3057ec82 100644 --- a/cpp/examples/arrow/compute_register_example.cc +++ b/cpp/examples/arrow/compute_register_example.cc @@ -74,30 +74,29 @@ class ExampleNode : public cp::ExecNode { ExampleNode(ExecNode* input, const ExampleNodeOptions&) : ExecNode(/*plan=*/input->plan(), /*inputs=*/{input}, /*input_labels=*/{"ignored"}, - /*output_schema=*/input->output_schema(), /*num_outputs=*/1) {} + /*output_schema=*/input->output_schema()) {} const char* kind_name() const override { return "ExampleNode"; } - arrow::Status StartProducing() override { - outputs_[0]->InputFinished(this, 0); - return arrow::Status::OK(); - } + arrow::Status StartProducing() override { return output_->InputFinished(this, 0); } void ResumeProducing(ExecNode* output, int32_t counter) override { inputs_[0]->ResumeProducing(this, counter); } + void PauseProducing(ExecNode* output, int32_t counter) override { inputs_[0]->PauseProducing(this, counter); } - void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); } - void StopProducing() override { inputs_[0]->StopProducing(); } + arrow::Status InputReceived(ExecNode* input, cp::ExecBatch batch) override { + return arrow::Status::OK(); + } - void InputReceived(ExecNode* input, cp::ExecBatch batch) override {} - void ErrorReceived(ExecNode* input, arrow::Status error) override {} - void InputFinished(ExecNode* input, int total_batches) override {} + arrow::Status InputFinished(ExecNode* input, int total_batches) override { + return arrow::Status::OK(); + } - arrow::Future<> finished() override { return inputs_[0]->finished(); } + void Abort() override {} }; arrow::Result ExampleExecNodeFactory(cp::ExecPlan* plan, diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 23f0a7c9f1a..88d72b11832 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -397,7 +397,6 @@ if(ARROW_COMPUTE) compute/exec/hash_join_node.cc compute/exec/key_hash.cc compute/exec/key_map.cc - compute/exec/map_node.cc compute/exec/order_by_impl.cc compute/exec/partition_util.cc compute/exec/options.cc diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index cca266ad691..2af7fedbf4e 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -68,8 +68,7 @@ class ScalarAggregateNode : public ExecNode { std::vector kernels, std::vector>> states) : ExecNode(plan, std::move(inputs), {"target"}, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), + /*output_schema=*/std::move(output_schema)), target_field_ids_(std::move(target_field_ids)), aggs_(std::move(aggs)), kernels_(std::move(kernels)), @@ -159,7 +158,7 @@ class ScalarAggregateNode : public ExecNode { return Status::OK(); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); util::tracing::Span span; START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived", @@ -170,25 +169,21 @@ class ScalarAggregateNode : public ExecNode { auto thread_index = plan_->GetThreadIndex(); - if (ErrorIfNotOk(DoConsume(ExecSpan(batch), thread_index))) return; + RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index)); if (input_counter_.Increment()) { - ErrorIfNotOk(Finish()); + return Finish(); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - EVENT(span_, "ErrorReceived", {{"error", error.message()}}); - DCHECK_EQ(input, inputs_[0]); - outputs_[0]->ErrorReceived(this, std::move(error)); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { EVENT(span_, "InputFinished", {{"batches.length", total_batches}}); DCHECK_EQ(input, inputs_[0]); if (input_counter_.SetTotal(total_batches)) { - ErrorIfNotOk(Finish()); + return Finish(); } + return Status::OK(); } Status StartProducing() override { @@ -197,8 +192,7 @@ class ScalarAggregateNode : public ExecNode { {"node.detail", ToString()}, {"node.kind", kind_name()}}); // Scalar aggregates will only output a single batch - outputs_[0]->InputFinished(this, 1); - return Status::OK(); + return output_->InputFinished(this, 1); } void PauseProducing(ExecNode* output, int32_t counter) override { @@ -209,18 +203,7 @@ class ScalarAggregateNode : public ExecNode { inputs_[0]->ResumeProducing(this, counter); } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override { - EVENT(span_, "StopProducing"); - if (input_counter_.Cancel()) { - finished_.MarkFinished(); - } - inputs_[0]->StopProducing(this); - } + void Abort() override {} protected: std::string ToStringExtra(int indent = 0) const override { @@ -251,9 +234,7 @@ class ScalarAggregateNode : public ExecNode { RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i])); } - outputs_[0]->InputReceived(this, std::move(batch)); - finished_.MarkFinished(); - return Status::OK(); + return output_->InputReceived(this, std::move(batch)); } const std::vector target_field_ids_; @@ -271,8 +252,7 @@ class GroupByNode : public ExecNode { std::vector key_field_ids, std::vector agg_src_field_ids, std::vector aggs, std::vector agg_kernels) - : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema), - /*num_outputs=*/1), + : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)), ctx_(ctx), key_field_ids_(std::move(key_field_ids)), agg_src_field_ids_(std::move(agg_src_field_ids)), @@ -280,15 +260,11 @@ class GroupByNode : public ExecNode { agg_kernels_(std::move(agg_kernels)) {} Status Init() override { + RETURN_NOT_OK(ExecNode::Init()); output_task_group_id_ = plan_->RegisterTaskGroup( - [this](size_t, int64_t task_id) { - OutputNthBatch(task_id); - return Status::OK(); - }, - [this](size_t) { - finished_.MarkFinished(); - return Status::OK(); - }); + [this](size_t, int64_t task_id) { return OutputNthBatch(task_id); }, + [](size_t) { return Status::OK(); }); + local_states_.resize(plan_->max_concurrency()); return Status::OK(); } @@ -475,12 +451,9 @@ class GroupByNode : public ExecNode { return out_data; } - void OutputNthBatch(int64_t n) { - // bail if StopProducing was called - if (finished_.is_finished()) return; - + Status OutputNthBatch(int64_t n) { int64_t batch_size = output_batch_size(); - outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size)); + return output_->InputReceived(this, out_data_.Slice(batch_size * n, batch_size)); } Status OutputResult() { @@ -496,12 +469,11 @@ class GroupByNode : public ExecNode { ARROW_ASSIGN_OR_RAISE(out_data_, Finalize()); int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size()); - outputs_[0]->InputFinished(this, static_cast(num_output_batches)); - RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_, num_output_batches)); - return Status::OK(); + RETURN_NOT_OK(output_->InputFinished(this, static_cast(num_output_batches))); + return plan_->StartTaskGroup(output_task_group_id_, num_output_batches); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); util::tracing::Span span; START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived", @@ -509,37 +481,24 @@ class GroupByNode : public ExecNode { {"node.label", label()}, {"batch.length", batch.length}}); - // bail if StopProducing was called - if (finished_.is_finished()) return; - DCHECK_EQ(input, inputs_[0]); - if (ErrorIfNotOk(Consume(ExecSpan(batch)))) return; - + RETURN_NOT_OK(Consume(ExecSpan(batch))); if (input_counter_.Increment()) { - ErrorIfNotOk(OutputResult()); + return OutputResult(); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - EVENT(span_, "ErrorReceived", {{"error", error.message()}}); - - DCHECK_EQ(input, inputs_[0]); - - outputs_[0]->ErrorReceived(this, std::move(error)); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { EVENT(span_, "InputFinished", {{"batches.length", total_batches}}); - // bail if StopProducing was called - if (finished_.is_finished()) return; - DCHECK_EQ(input, inputs_[0]); if (input_counter_.SetTotal(total_batches)) { - ErrorIfNotOk(OutputResult()); + return OutputResult(); } + return Status::OK(); } Status StartProducing() override { @@ -548,7 +507,6 @@ class GroupByNode : public ExecNode { {"node.detail", ToString()}, {"node.kind", kind_name()}}); - local_states_.resize(plan_->max_concurrency()); return Status::OK(); } @@ -562,15 +520,7 @@ class GroupByNode : public ExecNode { // Without spillover there is way to handle backpressure in this node } - void StopProducing(ExecNode* output) override { - EVENT(span_, "StopProducing"); - DCHECK_EQ(output, outputs_[0]); - - if (input_counter_.Cancel()) finished_.MarkFinished(); - inputs_[0]->StopProducing(this); - } - - void StopProducing() override { StopProducing(outputs_[0]); } + void Abort() override { input_counter_.Cancel(); } protected: std::string ToStringExtra(int indent = 0) const override { diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 3565d3b4d4a..fd9bf9e1e2b 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -794,26 +794,17 @@ class AsofJoinNode : public ExecNode { } } - void Process() { + Status Process() { std::lock_guard guard(gate_); - if (finished_.is_finished()) { - return; - } - // Process batches while we have data for (;;) { - Result> result = ProcessInner(); - - if (result.ok()) { - auto out_rb = *result; - if (!out_rb) break; - ++batches_produced_; - ExecBatch out_b(*out_rb); - outputs_[0]->InputReceived(this, std::move(out_b)); - } else { - ErrorIfNotOk(result.status()); - return; - } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, ProcessInner()); + + if (!result) break; + ++batches_produced_; + ExecBatch out_b(*result); + + return output_->InputReceived(this, std::move(out_b)); } // Report to the output the total batch count, if we've already finished everything @@ -822,9 +813,9 @@ class AsofJoinNode : public ExecNode { // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) if (state_.at(0)->Finished()) { - outputs_[0]->InputFinished(this, batches_produced_); - finished_.MarkFinished(); + return output_->InputFinished(this, batches_produced_); } + return Status::OK(); } void ProcessThread() { @@ -832,7 +823,10 @@ class AsofJoinNode : public ExecNode { if (!process_.Pop()) { return; } - Process(); + Status st = Process(); + if (!st.ok()) { + plan_->Abort(); + } } } @@ -1075,24 +1069,19 @@ class AsofJoinNode : public ExecNode { const char* kind_name() const override { return "AsofJoinNode"; } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { // Get the input ARROW_DCHECK(std_has(inputs_, input)); size_t k = std_find(inputs_, input) - inputs_.begin(); // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); - Status st = state_.at(k)->Push(rb); - if (!st.ok()) { - ErrorReceived(input, st); - return; - } + RETURN_NOT_OK(state_.at(k)->Push(rb)); process_.Push(true); + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - outputs_[0]->ErrorReceived(this, std::move(error)); - } - void InputFinished(ExecNode* input, int total_batches) override { + + Status InputFinished(ExecNode* input, int total_batches) override { { std::lock_guard guard(gate_); ARROW_DCHECK(std_has(inputs_, input)); @@ -1104,22 +1093,18 @@ class AsofJoinNode : public ExecNode { // know whether the RHS of the join is up-to-date until we know that the table is // finished. process_.Push(true); - } - Status StartProducing() override { - finished_ = arrow::Future<>::Make(); return Status::OK(); } - void PauseProducing(ExecNode* output, int32_t counter) override {} - void ResumeProducing(ExecNode* output, int32_t counter) override {} - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - void StopProducing() override { + + Status StartProducing() override { return Status::OK(); } + + void Abort() override { process_.Clear(); process_.Push(false); } - arrow::Future<> finished() override { return finished_; } + + void PauseProducing(ExecNode* output, int32_t counter) override {} + void ResumeProducing(ExecNode* output, int32_t counter) override {} private: arrow::Future<> finished_; @@ -1128,6 +1113,7 @@ class AsofJoinNode : public ExecNode { std::vector> key_hashers_; bool must_hash_; bool may_rehash_; + // InputStates // Each input state correponds to an input table std::vector> state_; @@ -1152,8 +1138,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector> key_hashers, bool must_hash, bool may_rehash) : ExecNode(plan, inputs, input_labels, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), + /*output_schema=*/std::move(output_schema)), indices_of_on_key_(std::move(indices_of_on_key)), indices_of_by_key_(std::move(indices_of_by_key)), key_hashers_(std::move(key_hashers)), @@ -1162,7 +1147,6 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { - finished_ = arrow::Future<>::MakeFinished(); } namespace internal { diff --git a/cpp/src/arrow/compute/exec/benchmark_util.cc b/cpp/src/arrow/compute/exec/benchmark_util.cc index dcc7ca6e165..7e55e84f2d9 100644 --- a/cpp/src/arrow/compute/exec/benchmark_util.cc +++ b/cpp/src/arrow/compute/exec/benchmark_util.cc @@ -67,11 +67,11 @@ Status BenchmarkIsolatedNodeOverhead(benchmark::State& state, int task_group_id = scheduler->RegisterTaskGroup( [&](size_t thread_id, int64_t task_id) { - node->InputReceived(source_node, data.batches[task_id]); - return Status::OK(); + return node->InputReceived(source_node, data.batches[task_id]); }, [&](size_t thread_id) { - node->InputFinished(source_node, static_cast(data.batches.size())); + RETURN_NOT_OK( + node->InputFinished(source_node, static_cast(data.batches.size()))); std::unique_lock lk(mutex); all_tasks_finished_cv.notify_one(); return Status::OK(); @@ -96,9 +96,6 @@ Status BenchmarkIsolatedNodeOverhead(benchmark::State& state, ARROW_RETURN_NOT_OK( scheduler->StartTaskGroup(thread_indexer(), task_group_id, num_batches)); all_tasks_finished_cv.wait(lk); - if (!node->finished().is_finished()) { - return Status::Invalid("All tasks were finsihed but the node was not finished"); - } } state.counters["rows_per_second"] = benchmark::Counter( static_cast(state.iterations() * num_batches * batch_size), diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 83cd6b1ba80..e018c2ead51 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -18,6 +18,7 @@ #include "arrow/compute/exec/exec_plan.h" #include +#include #include #include #include @@ -56,29 +57,27 @@ struct ExecPlanImpl : public ExecPlan { ~ExecPlanImpl() override { if (started_ && !finished_.is_finished()) { ARROW_LOG(WARNING) << "Plan was destroyed before finishing"; - StopProducing(); + Abort(); finished().Wait(); } } size_t GetThreadIndex() { return thread_indexer_(); } size_t max_concurrency() const { return thread_indexer_.Capacity(); } + const std::vector>& nodes() const { return nodes_; } ExecNode* AddNode(std::unique_ptr node) { if (node->label().empty()) { node->SetLabel(std::to_string(auto_label_counter_++)); } - if (node->num_inputs() == 0) { - sources_.push_back(node.get()); - } - if (node->num_outputs() == 0) { - sinks_.push_back(node.get()); - } nodes_.push_back(std::move(node)); return nodes_.back().get(); } Result> BeginExternalTask() { + // The task group isn't relevant in synchronous execution mode + if (!exec_context_->executor()) return Future<>::Make(); + Future<> completion_future = Future<>::Make(); if (async_scheduler_->AddSimpleTask( [completion_future] { return completion_future; })) { @@ -93,7 +92,19 @@ struct ExecPlanImpl : public ExecPlan { // Adds a task which submits fn to the executor and tracks its progress. If we're // aborted then the task is ignored and fn is not executed. async_scheduler_->AddSimpleTask( - [executor, fn]() { return executor->Submit(std::move(fn)); }); + [this, executor, fn = std::move(fn)]() + { + ARROW_ASSIGN_OR_RAISE(Future<> fut, + executor->Submit(stop_source_.token(), std::move(fn))); + fut.AddCallback([this](const Status& status) { + if (!status.ok()) { + std::lock_guard guard(abort_mutex_); + errors_.emplace_back(std::move(status)); + AbortUnlocked(); + } + }); + return Status::OK(); + }); return Status::OK(); } @@ -143,16 +154,7 @@ struct ExecPlanImpl : public ExecPlan { return Status::Invalid("restarted ExecPlan"); } - std::vector> futures; - for (auto& n : nodes_) { - RETURN_NOT_OK(n->Init()); - futures.push_back(n->finished()); - } - - AllFinished(futures).AddCallback([this](const Status& st) { - error_st_ = st; - EndTaskGroup(); - }); + for (std::unique_ptr& n : nodes_) RETURN_NOT_OK(n->Init()); task_scheduler_->RegisterEnd(); int num_threads = 1; @@ -169,163 +171,77 @@ struct ExecPlanImpl : public ExecPlan { /*concurrent_tasks=*/2 * num_threads, sync_execution)); started_ = true; - // producers precede consumers - sorted_nodes_ = TopoSort(); - - Status st = Status::OK(); - - using rev_it = std::reverse_iterator; - for (rev_it it(sorted_nodes_.end()), end(sorted_nodes_.begin()); it != end; ++it) { - auto node = *it; - - EVENT(span_, "StartProducing:" + node->label(), - {{"node.label", node->label()}, {"node.kind_name", node->kind_name()}}); - st = node->StartProducing(); - EVENT(span_, "StartProducing:" + node->label(), {{"status", st.ToString()}}); + for (std::unique_ptr& n : nodes_) { + Status st = n->StartProducing(); if (!st.ok()) { - // Stop nodes that successfully started, in reverse order - stopped_ = true; - StopProducingImpl(it.base(), sorted_nodes_.end()); - for (NodeVector::iterator fw_it = sorted_nodes_.begin(); fw_it != it.base(); - ++fw_it) { - Future<> fut = (*fw_it)->finished(); - if (!fut.is_finished()) fut.MarkFinished(); - } + Abort(); return st; } } - return st; + // StartProducing will have added some tasks to the task group. + // Now we end the task group so that as soon as we run out of tasks, + // we've finished executing. + EndTaskGroup(); + return Status::OK(); } void EndTaskGroup() { bool expected = false; if (group_ended_.compare_exchange_strong(expected, true)) { - async_scheduler_->End(); - async_scheduler_->OnFinished().AddCallback([this](const Status& st) { + async_scheduler_->End(); + async_scheduler_->OnFinished().AddCallback([this](const Status& st) { MARK_SPAN(span_, error_st_ & st); END_SPAN(span_); - finished_.MarkFinished(error_st_ & st); + if (aborted_) { + for (std::unique_ptr& node : nodes_) node->Abort(); + } + if (!errors_.empty()) + finished_.MarkFinished(errors_[0]); + else + finished_.MarkFinished(st); }); } } - void StopProducing() { - DCHECK(started_) << "stopped an ExecPlan which never started"; - EVENT(span_, "StopProducing"); - stopped_ = true; - task_scheduler_->Abort( - [this]() { StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end()); }); + void Abort() { + DCHECK(started_) << "aborted an ExecPlan which never started"; + EVENT(span_, "Abort"); + if (finished_.is_finished()) return; + std::lock_guard guard(abort_mutex_); + AbortUnlocked(); } - template - void StopProducingImpl(It begin, It end) { - for (auto it = begin; it != end; ++it) { - auto node = *it; - EVENT(span_, "StopProducing:" + node->label(), - {{"node.label", node->label()}, {"node.kind_name", node->kind_name()}}); - node->StopProducing(); + void AbortUnlocked() { + if (!aborted_) { + aborted_ = true; + stop_source_.RequestStop(); + EndTaskGroup(); + task_scheduler_->Abort([]() {}); } } - NodeVector TopoSort() const { - struct Impl { - const std::vector>& nodes; - std::unordered_set visited; - NodeVector sorted; - - explicit Impl(const std::vector>& nodes) : nodes(nodes) { - visited.reserve(nodes.size()); - sorted.resize(nodes.size()); - - for (const auto& node : nodes) { - Visit(node.get()); - } - - DCHECK_EQ(visited.size(), nodes.size()); - } - - void Visit(ExecNode* node) { - if (visited.count(node) != 0) return; - - for (auto input : node->inputs()) { - // Ensure that producers are inserted before this consumer - Visit(input); - } - - sorted[visited.size()] = node; - visited.insert(node); - } - }; - - return std::move(Impl{nodes_}.sorted); - } - - // This function returns a node vector and a vector of integers with the - // number of spaces to add as an indentation. The main difference between - // this function and the TopoSort function is that here we visit the nodes - // in reverse order and we can have repeated nodes if necessary. - // For example, in the following plan: - // s1 --> s3 - - // - - - // - -> s5 --> s6 - // - - - // s2 --> s4 - - // Toposort node vector: s1 s2 s3 s4 s5 s6 - // OrderedNodes node vector: s6 s5 s3 s1 s4 s2 s1 - std::pair> OrderedNodes() const { - struct Impl { - const std::vector>& nodes; - std::unordered_set visited; - std::unordered_set marked; - NodeVector sorted; - std::vector indents; - - explicit Impl(const std::vector>& nodes) : nodes(nodes) { - visited.reserve(nodes.size()); - - for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { - if (visited.count(it->get()) != 0) continue; - Visit(it->get()); - } - - DCHECK_EQ(visited.size(), nodes.size()); - } - - void Visit(ExecNode* node, int indent = 0) { - marked.insert(node); - for (auto input : node->inputs()) { - if (marked.count(input) != 0) continue; - Visit(input, indent + 1); - } - marked.erase(node); - - indents.push_back(indent); - sorted.push_back(node); - visited.insert(node); - } - }; - - auto result = Impl{nodes_}; - return std::make_pair(result.sorted, result.indents); - } - std::string ToString() const { std::stringstream ss; ss << "ExecPlan with " << nodes_.size() << " nodes:" << std::endl; - auto sorted = OrderedNodes(); - for (size_t i = sorted.first.size(); i > 0; --i) { - for (int j = 0; j < sorted.second[i - 1]; ++j) ss << " "; - ss << sorted.first[i - 1]->ToString(sorted.second[i - 1]) << std::endl; + for (const std::unique_ptr& node : nodes_) { + if (!node->output()) { + PrintSubtree(node.get(), ss, /*indent=*/0); + } } return ss.str(); } - Status error_st_; + void PrintSubtree(ExecNode* node, std::stringstream& ss, int indent) const { + for (int i = 0; i < indent; i++) ss << " "; + ss << node->ToString(indent) << std::endl; + for (ExecNode* input : node->inputs()) { + PrintSubtree(input, ss, indent + 1); + } + } + Future<> finished_ = Future<>::Make(); - bool started_ = false, stopped_ = false; + bool started_ = false; std::vector> nodes_; - NodeVector sources_, sinks_; - NodeVector sorted_nodes_; uint32_t auto_label_counter_ = 0; util::tracing::Span span_; std::shared_ptr metadata_; @@ -335,6 +251,11 @@ struct ExecPlanImpl : public ExecPlan { std::unique_ptr async_scheduler_ = util::AsyncTaskScheduler::Make(); std::unique_ptr task_scheduler_ = TaskScheduler::Make(); + + std::mutex abort_mutex_; + bool aborted_ = false; + StopSource stop_source_; + std::vector errors_; }; ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast(ptr); } @@ -364,14 +285,11 @@ ExecNode* ExecPlan::AddNode(std::unique_ptr node) { return ToDerived(this)->AddNode(std::move(node)); } -const ExecPlan::NodeVector& ExecPlan::sources() const { - return ToDerived(this)->sources_; -} - -const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; } - size_t ExecPlan::GetThreadIndex() { return ToDerived(this)->GetThreadIndex(); } size_t ExecPlan::max_concurrency() const { return ToDerived(this)->max_concurrency(); } +const std::vector>& ExecPlan::nodes() const { + return ToDerived(this)->nodes(); +} Result> ExecPlan::BeginExternalTask() { return ToDerived(this)->BeginExternalTask(); @@ -398,8 +316,7 @@ util::AsyncTaskScheduler* ExecPlan::async_scheduler() { Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } - -void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); } +void ExecPlan::Abort() { ToDerived(this)->Abort(); } Future<> ExecPlan::finished() { return ToDerived(this)->finished_; } @@ -413,18 +330,23 @@ std::string ExecPlan::ToString() const { return ToDerived(this)->ToString(); } ExecNode::ExecNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - std::shared_ptr output_schema, int num_outputs) + std::shared_ptr output_schema) : plan_(plan), inputs_(std::move(inputs)), input_labels_(std::move(input_labels)), output_schema_(std::move(output_schema)), - num_outputs_(num_outputs) { + output_(nullptr) { for (auto input : inputs_) { - input->outputs_.push_back(this); + input->output_ = this; } } -Status ExecNode::Init() { return Status::OK(); } +Status ExecNode::Init() { + START_COMPUTE_SPAN( + span_, std::string(kind_name()) + ":" + label(), + {{"node.label", label()}, {"node.detail", ToString()}, {"node.kind", kind_name()}}); + return Status::OK(); +} Status ExecNode::Validate() const { if (inputs_.size() != input_labels_.size()) { @@ -432,19 +354,12 @@ Status ExecNode::Validate() const { num_inputs(), ", actual ", input_labels_.size(), ")"); } - if (static_cast(outputs_.size()) != num_outputs_) { - return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ", - num_outputs(), ", actual ", outputs_.size(), ")"); - } - - for (auto out : outputs_) { - auto input_index = GetNodeIndex(out->inputs(), this); - if (!input_index) { - return Status::Invalid("Node '", label(), "' outputs to node '", out->label(), + if (output_) { + auto input_index = GetNodeIndex(output_->inputs(), this); + if (!input_index) + return Status::Invalid("Node '", label(), "' outputs to node '", output_->label(), "' but is not listed as an input."); - } } - return Status::OK(); } @@ -469,15 +384,6 @@ std::string ExecNode::ToString(int indent) const { std::string ExecNode::ToStringExtra(int indent = 0) const { return ""; } -bool ExecNode::ErrorIfNotOk(Status status) { - if (status.ok()) return false; - - for (auto out : outputs_) { - out->ErrorReceived(this, out == outputs_.back() ? std::move(status) : status); - } - return true; -} - std::shared_ptr MakeGeneratorReader( std::shared_ptr schema, std::function>()> gen, MemoryPool* pool) { diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 5d929aa3057..2c16dd9580c 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -74,6 +74,8 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { /// e.g. make an array of thread-locals off this. size_t max_concurrency() const; + const std::vector>& nodes() const; + /// \brief Start an external task /// /// This should be avoided if possible. It is kept in for now for legacy @@ -122,12 +124,6 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { util::AsyncTaskScheduler* async_scheduler(); - /// The initial inputs - const NodeVector& sources() const; - - /// The final outputs - const NodeVector& sinks() const; - Status Validate(); /// \brief Start producing on all nodes @@ -136,11 +132,7 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { /// is started before all of its inputs. Status StartProducing(); - /// \brief Stop producing on all nodes - /// - /// Nodes are stopped in topological order, such that any node - /// is stopped before all of its outputs. - void StopProducing(); + void Abort(); /// \brief A future which will be marked finished when all nodes have stopped producing. Future<> finished(); @@ -182,7 +174,6 @@ class ARROW_EXPORT ExecNode { // The number of inputs/outputs expected by this node int num_inputs() const { return static_cast(inputs_.size()); } - int num_outputs() const { return num_outputs_; } /// This node's predecessors in the exec plan const NodeVector& inputs() const { return inputs_; } @@ -190,8 +181,8 @@ class ARROW_EXPORT ExecNode { /// \brief Labels identifying the function of each input. const std::vector& input_labels() const { return input_labels_; } - /// This node's successors in the exec plan - const NodeVector& outputs() const { return outputs_; } + /// This node's successor in the exec plan + const ExecNode* output() const { return output_; } /// The datatypes for batches produced by this node const std::shared_ptr& output_schema() const { return output_schema_; } @@ -220,17 +211,14 @@ class ARROW_EXPORT ExecNode { /// and StopProducing() /// Transfer input batch to ExecNode - virtual void InputReceived(ExecNode* input, ExecBatch batch) = 0; - - /// Signal error to ExecNode - virtual void ErrorReceived(ExecNode* input, Status error) = 0; + virtual Status InputReceived(ExecNode* input, ExecBatch batch) = 0; /// Mark the inputs finished after the given number of batches. /// /// This may be called before all inputs are received. This simply fixes /// the total number of incoming batches for an input, so that the ExecNode /// knows when it has received all input, regardless of order. - virtual void InputFinished(ExecNode* input, int total_batches) = 0; + virtual Status InputFinished(ExecNode* input, int total_batches) = 0; /// \brief Perform any needed initialization /// @@ -328,27 +316,14 @@ class ARROW_EXPORT ExecNode { /// This may be called any number of times after StartProducing() succeeds. virtual void ResumeProducing(ExecNode* output, int32_t counter) = 0; - /// \brief Stop producing definitively to a single output - /// - /// This call is a hint that an output node has completed and is not willing - /// to receive any further data. - virtual void StopProducing(ExecNode* output) = 0; - - /// \brief Stop producing definitively to all outputs - virtual void StopProducing() = 0; - - /// \brief A future which will be marked finished when this node has stopped producing. - virtual Future<> finished() { return finished_; } + /// \brief Abort execution and perform any needed cleanup (such as closing files, etc.) + virtual void Abort() = 0; std::string ToString(int indent = 0) const; protected: ExecNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - std::shared_ptr output_schema, int num_outputs); - - // A helper method to send an error status to all outputs. - // Returns true if the status was an error. - bool ErrorIfNotOk(Status status); + std::shared_ptr output_schema); /// Provide extra info to include in the string representation. virtual std::string ToStringExtra(int indent) const; @@ -360,11 +335,7 @@ class ARROW_EXPORT ExecNode { std::vector input_labels_; std::shared_ptr output_schema_; - int num_outputs_; - NodeVector outputs_; - - // Future to sync finished - Future<> finished_ = Future<>::Make(); + ExecNode* output_; util::tracing::Span span_; }; diff --git a/cpp/src/arrow/compute/exec/filter_node.cc b/cpp/src/arrow/compute/exec/filter_node.cc index 19d18ca6081..72879635f13 100644 --- a/cpp/src/arrow/compute/exec/filter_node.cc +++ b/cpp/src/arrow/compute/exec/filter_node.cc @@ -19,7 +19,6 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" -#include "arrow/compute/exec/map_node.h" #include "arrow/compute/exec/options.h" #include "arrow/datum.h" #include "arrow/result.h" @@ -35,16 +34,16 @@ using internal::checked_cast; namespace compute { namespace { -class FilterNode : public MapNode { +class FilterNode : public ExecNode { public: FilterNode(ExecPlan* plan, std::vector inputs, std::shared_ptr output_schema, Expression filter) - : MapNode(plan, std::move(inputs), std::move(output_schema)), + : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"}, + std::move(output_schema)), filter_(std::move(filter)) {} static Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { - RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "FilterNode")); auto schema = inputs[0]->output_schema(); const auto& filter_options = checked_cast(options); @@ -66,32 +65,31 @@ class FilterNode : public MapNode { const char* kind_name() const override { return "FilterNode"; } - Result DoFilter(const ExecBatch& target) { + Result DoFilter(ExecBatch batch) { ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, - SimplifyWithGuarantee(filter_, target.guarantee)); - + SimplifyWithGuarantee(filter_, batch.guarantee)); util::tracing::Span span; START_COMPUTE_SPAN(span, "Filter", {{"filter.expression", ToStringExtra()}, {"filter.expression.simplified", simplified_filter.ToString()}, - {"filter.length", target.length}}); + {"filter.length", batch.length}}); - ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target, + ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, batch, plan()->exec_context())); if (mask.is_scalar()) { const auto& mask_scalar = mask.scalar_as(); if (mask_scalar.is_valid && mask_scalar.value) { - return target; + return batch; } - return target.Slice(0, 0); + return batch.Slice(0, 0); } // if the values are all scalar then the mask must also be - DCHECK(!std::all_of(target.values.begin(), target.values.end(), + DCHECK(!std::all_of(batch.values.begin(), batch.values.end(), [](const Datum& value) { return value.is_scalar(); })); - auto values = target.values; + auto values = batch.values; for (auto& value : values) { if (value.is_scalar()) continue; ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults())); @@ -99,23 +97,30 @@ class FilterNode : public MapNode { return ExecBatch::Make(std::move(values)); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status StartProducing() override { return Status::OK(); } + + void PauseProducing(ExecNode* output, int32_t counter) override { + inputs_[0]->PauseProducing(this, counter); + } + + void ResumeProducing(ExecNode* output, int32_t counter) override { + inputs_[0]->ResumeProducing(this, counter); + } + + Status InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); DCHECK_EQ(input, inputs_[0]); - auto func = [this](ExecBatch batch) { - util::tracing::Span span; - START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived", - {{"filter", ToStringExtra()}, - {"node.label", label()}, - {"batch.length", batch.length}}); - auto result = DoFilter(std::move(batch)); - MARK_SPAN(span, result.status()); - END_SPAN(span); - return result; - }; - this->SubmitTask(std::move(func), std::move(batch)); + ARROW_ASSIGN_OR_RAISE(ExecBatch filtered, DoFilter(std::move(batch))); + return output_->InputReceived(this, std::move(filtered)); + } + + Status InputFinished(ExecNode* input, int total_batches) override { + END_SPAN(span_); + return output_->InputFinished(this, total_batches); } + void Abort() override {} + protected: std::string ToStringExtra(int indent = 0) const override { return "filter=" + filter_.ToString(); diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index da1710fe08d..1d9d6be6050 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -181,9 +181,9 @@ class HashJoinBasicImpl : public HashJoinImpl { } } - void ProbeBatch_OutputOne(int64_t batch_size_next, ExecBatch* opt_left_key, - ExecBatch* opt_left_payload, ExecBatch* opt_right_key, - ExecBatch* opt_right_payload) { + Status ProbeBatch_OutputOne(int64_t batch_size_next, ExecBatch* opt_left_key, + ExecBatch* opt_left_payload, ExecBatch* opt_right_key, + ExecBatch* opt_right_payload) { ExecBatch result({}, batch_size_next); int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT); int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT); @@ -232,11 +232,12 @@ class HashJoinBasicImpl : public HashJoinImpl { : opt_right_payload->values[from_payload.get(icol)]; } - output_batch_callback_(0, std::move(result)); + RETURN_NOT_OK(output_batch_callback_(0, std::move(result))); // Update the counter of produced batches // num_batches_produced_++; + return Status::OK(); } Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state, @@ -404,12 +405,10 @@ class HashJoinBasicImpl : public HashJoinImpl { hash_table_payloads_.Decode(batch_size_next, opt_right_ids)); } - ProbeBatch_OutputOne(batch_size_next, has_left ? &left_key : nullptr, - has_left_payload ? &left_payload : nullptr, - has_right ? &right_key : nullptr, - has_right_payload ? &right_payload : nullptr); - - return Status::OK(); + return ProbeBatch_OutputOne(batch_size_next, has_left ? &left_key : nullptr, + has_left_payload ? &left_payload : nullptr, + has_right ? &right_key : nullptr, + has_right_payload ? &right_payload : nullptr); } Status ProbeBatch_OutputAll(size_t thread_index, const RowEncoder& exec_batch_keys, @@ -683,8 +682,7 @@ class HashJoinBasicImpl : public HashJoinImpl { return Status::Cancelled("Hash join cancelled"); } END_SPAN(span_); - finished_callback_(num_batches_produced_.load()); - return Status::OK(); + return finished_callback_(num_batches_produced_.load()); } Status ScanHashTable(size_t thread_index) { diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 0c5e43467e9..70bd85fd051 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -38,9 +38,9 @@ using arrow::util::AccumulationQueue; class HashJoinImpl { public: - using OutputBatchCallback = std::function; + using OutputBatchCallback = std::function; using BuildFinishedCallback = std::function; - using FinishedCallback = std::function; + using FinishedCallback = std::function; using RegisterTaskGroupCallback = std::function, std::function)>; using StartTaskGroupCallback = std::function; diff --git a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc index a59ec03749a..1537c1e0e4e 100644 --- a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc @@ -162,8 +162,9 @@ class JoinBenchmark { ctx_.get(), settings.join_type, settings.num_threads, &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), std::move(key_cmp), std::move(filter), std::move(register_task_group_callback), - std::move(start_task_group_callback), [](int64_t, ExecBatch) {}, - [](int64_t x) {})); + std::move(start_task_group_callback), + [](int64_t, ExecBatch) { return Status::OK(); }, + [](int64_t x) { return Status::OK(); })); task_group_probe_ = scheduler_->RegisterTaskGroup( [this](size_t thread_index, int64_t task_id) -> Status { diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 666ed92ec08..88d3453926d 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -698,8 +698,7 @@ class HashJoinNode : public ExecNode { std::unique_ptr schema_mgr, Expression filter, std::unique_ptr impl) : ExecNode(plan, inputs, {"left", "right"}, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), + /*output_schema=*/std::move(output_schema)), join_type_(join_options.join_type), key_cmp_(join_options.key_cmp), filter_(std::move(filter)), @@ -885,10 +884,10 @@ class HashJoinNode : public ExecNode { return Status::OK(); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); if (complete_.load()) { - return; + return Status::OK(); } size_t thread_index = plan_->GetThreadIndex(); @@ -899,35 +898,19 @@ class HashJoinNode : public ExecNode { START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived", {{"batch.length", batch.length}}); - Status status = side == 0 ? OnProbeSideBatch(thread_index, std::move(batch)) - : OnBuildSideBatch(thread_index, std::move(batch)); - - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - return; - } + if (side == 0) + RETURN_NOT_OK(OnProbeSideBatch(thread_index, std::move(batch))); + else + RETURN_NOT_OK(OnBuildSideBatch(thread_index, std::move(batch))); if (batch_count_[side].Increment()) { - status = side == 0 ? OnProbeSideFinished(thread_index) - : OnBuildSideFinished(thread_index); - - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - return; - } + return side == 0 ? OnProbeSideFinished(thread_index) + : OnBuildSideFinished(thread_index); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - EVENT(span_, "ErrorReceived", {{"error", error.message()}}); - DCHECK_EQ(input, inputs_[0]); - StopProducing(); - outputs_[0]->ErrorReceived(this, std::move(error)); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); size_t thread_index = plan_->GetThreadIndex(); int side = (input == inputs_[0]) ? 0 : 1; @@ -935,19 +918,15 @@ class HashJoinNode : public ExecNode { EVENT(span_, "InputFinished", {{"side", side}, {"batches.length", total_batches}}); if (batch_count_[side].SetTotal(total_batches)) { - Status status = side == 0 ? OnProbeSideFinished(thread_index) - : OnBuildSideFinished(thread_index); - - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - return; - } + return side == 0 ? OnProbeSideFinished(thread_index) + : OnBuildSideFinished(thread_index); } + return Status::OK(); } Status Init() override { RETURN_NOT_OK(ExecNode::Init()); + if (plan_->UseLegacyBatching()) { return Status::Invalid( "The plan was configured to use legacy batching but contained a join node " @@ -981,9 +960,9 @@ class HashJoinNode : public ExecNode { [this](int task_group_id, int64_t num_tasks) { return plan_->StartTaskGroup(task_group_id, num_tasks); }, - [this](int64_t, ExecBatch batch) { this->OutputBatchCallback(batch); }, + [this](int64_t, ExecBatch batch) { return this->OutputBatchCallback(batch); }, [this](int64_t total_num_batches) { - this->FinishedCallback(total_num_batches); + return this->FinishedCallback(total_num_batches); })); task_group_probe_ = plan_->RegisterTaskGroup( @@ -999,11 +978,6 @@ class HashJoinNode : public ExecNode { } Status StartProducing() override { - START_COMPUTE_SPAN(span_, std::string(kind_name()) + ":" + label(), - {{"node.label", label()}, - {"node.detail", ToString()}, - {"node.kind", kind_name()}}); - END_SPAN_ON_FUTURE_COMPLETION(span_, finished_); RETURN_NOT_OK(pushdown_context_.StartProducing()); return Status::OK(); } @@ -1016,19 +990,10 @@ class HashJoinNode : public ExecNode { // TODO(ARROW-16246) } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override { - EVENT(span_, "StopProducing"); + void Abort() override { bool expected = false; if (complete_.compare_exchange_strong(expected, true)) { - for (auto&& input : inputs_) { - input->StopProducing(this); - } - impl_->Abort([this]() { finished_.MarkFinished(); }); + impl_->Abort([]() {}); } } @@ -1038,16 +1003,17 @@ class HashJoinNode : public ExecNode { } private: - void OutputBatchCallback(ExecBatch batch) { - outputs_[0]->InputReceived(this, std::move(batch)); + Status OutputBatchCallback(ExecBatch batch) { + return output_->InputReceived(this, std::move(batch)); } - void FinishedCallback(int64_t total_num_batches) { + Status FinishedCallback(int64_t total_num_batches) { bool expected = false; if (complete_.compare_exchange_strong(expected, true)) { - outputs_[0]->InputFinished(this, static_cast(total_num_batches)); - finished_.MarkFinished(); + END_SPAN(span_); + return output_->InputFinished(this, static_cast(total_num_batches)); } + return Status::OK(); } private: diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index adc5ec70ebd..d36491ebb97 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -921,7 +921,7 @@ Result> HashJoinWithExecPlan( auto batches_fut = StartAndCollect(plan.get(), sink_gen); if (!batches_fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)) { - plan->StopProducing(); + plan->Abort(); // If this second wait fails then there isn't much we can do. We will abort // and probably get a segmentation fault. plan->finished().Wait(::arrow::kDefaultAssertFinishesWaitSeconds); diff --git a/cpp/src/arrow/compute/exec/map_node.cc b/cpp/src/arrow/compute/exec/map_node.cc deleted file mode 100644 index 16201ea1290..00000000000 --- a/cpp/src/arrow/compute/exec/map_node.cc +++ /dev/null @@ -1,122 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/compute/exec/map_node.h" - -#include -#include -#include -#include -#include - -#include "arrow/compute/exec.h" -#include "arrow/compute/exec/expression.h" -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/util/logging.h" -#include "arrow/util/tracing_internal.h" - -namespace arrow { -namespace compute { - -MapNode::MapNode(ExecPlan* plan, std::vector inputs, - std::shared_ptr output_schema) - : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"}, - std::move(output_schema), - /*num_outputs=*/1) {} - -void MapNode::ErrorReceived(ExecNode* input, Status error) { - DCHECK_EQ(input, inputs_[0]); - EVENT(span_, "ErrorReceived", {{"error.message", error.message()}}); - outputs_[0]->ErrorReceived(this, std::move(error)); -} - -void MapNode::InputFinished(ExecNode* input, int total_batches) { - DCHECK_EQ(input, inputs_[0]); - EVENT(span_, "InputFinished", {{"batches.length", total_batches}}); - outputs_[0]->InputFinished(this, total_batches); - if (input_counter_.SetTotal(total_batches)) { - this->Finish(); - } -} - -Status MapNode::StartProducing() { - START_COMPUTE_SPAN( - span_, std::string(kind_name()) + ":" + label(), - {{"node.label", label()}, {"node.detail", ToString()}, {"node.kind", kind_name()}}); - return Status::OK(); -} - -void MapNode::PauseProducing(ExecNode* output, int32_t counter) { - inputs_[0]->PauseProducing(this, counter); -} - -void MapNode::ResumeProducing(ExecNode* output, int32_t counter) { - inputs_[0]->ResumeProducing(this, counter); -} - -void MapNode::StopProducing(ExecNode* output) { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); -} - -void MapNode::StopProducing() { - EVENT(span_, "StopProducing"); - if (input_counter_.Cancel()) { - this->Finish(); - } - inputs_[0]->StopProducing(this); -} - -void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, - ExecBatch batch) { - Status status; - // This will be true if the node is stopped early due to an error or manual - // cancellation - if (input_counter_.Completed()) { - return; - } - auto task = [this, map_fn, batch]() { - auto guarantee = batch.guarantee; - auto output_batch = map_fn(std::move(batch)); - if (ErrorIfNotOk(output_batch.status())) { - return output_batch.status(); - } - output_batch->guarantee = guarantee; - outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe()); - return Status::OK(); - }; - - status = task(); - if (!status.ok()) { - if (input_counter_.Cancel()) { - this->Finish(status); - } - inputs_[0]->StopProducing(this); - return; - } - if (input_counter_.Increment()) { - this->Finish(); - } -} - -void MapNode::Finish(Status finish_st /*= Status::OK()*/) { - this->finished_.MarkFinished(finish_st); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/map_node.h b/cpp/src/arrow/compute/exec/map_node.h deleted file mode 100644 index 88241ece592..00000000000 --- a/cpp/src/arrow/compute/exec/map_node.h +++ /dev/null @@ -1,75 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -/// \brief MapNode is an ExecNode type class which process a task like filter/project -/// (See SubmitTask method) to each given ExecBatch object, which have one input, one -/// output, and are pure functions on the input -/// -/// A simple parallel runner is created with a "map_fn" which is just a function that -/// takes a batch in and returns a batch. This simple parallel runner also needs an -/// executor (use simple synchronous runner if there is no executor) - -#pragma once - -#include -#include -#include -#include - -#include "arrow/compute/exec/exec_plan.h" -#include "arrow/compute/exec/util.h" -#include "arrow/compute/type_fwd.h" -#include "arrow/status.h" -#include "arrow/type_fwd.h" -#include "arrow/util/cancel.h" -#include "arrow/util/type_fwd.h" -#include "arrow/util/visibility.h" - -namespace arrow { -namespace compute { - -class ARROW_EXPORT MapNode : public ExecNode { - public: - MapNode(ExecPlan* plan, std::vector inputs, - std::shared_ptr output_schema); - - void ErrorReceived(ExecNode* input, Status error) override; - - void InputFinished(ExecNode* input, int total_batches) override; - - Status StartProducing() override; - - void PauseProducing(ExecNode* output, int32_t counter) override; - - void ResumeProducing(ExecNode* output, int32_t counter) override; - - void StopProducing(ExecNode* output) override; - - void StopProducing() override; - - protected: - void SubmitTask(std::function(ExecBatch)> map_fn, ExecBatch batch); - - virtual void Finish(Status finish_st = Status::OK()); - - protected: - // Counter for the number of batches received - AtomicCounter input_counter_; -}; - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 9f417d46ed3..a7f271014a1 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -56,65 +56,47 @@ TEST(ExecPlanConstruction, Empty) { TEST(ExecPlanConstruction, SingleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/0); + std::ignore = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}); ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ElementsAre(node)); - ASSERT_THAT(plan->sinks(), ElementsAre(node)); ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/1); + std::ignore = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}); // Output not bound - ASSERT_THAT(plan->Validate(), Raises(StatusCode::Invalid)); -} - -TEST(ExecPlanConstruction, SourceSink) { - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*inputs=*/{}, /*num_outputs=*/1); - auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{source}, /*num_outputs=*/0); - - ASSERT_OK(plan->Validate()); - EXPECT_THAT(plan->sources(), ElementsAre(source)); - EXPECT_THAT(plan->sinks(), ElementsAre(sink)); + ASSERT_THAT(plan->Validate(), Raises(StatusCode::OK)); } TEST(ExecPlanConstruction, MultipleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2); + auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}); - auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1); + auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}); - auto process1 = - MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2); + auto process1 = MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}); - auto process2 = MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1, source2}, - /*num_outputs=*/1); + auto process2 = MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1, source2}); auto process3 = - MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, process2, process1}, - /*num_outputs=*/1); + MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, process2, process1}); - auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0); + std::ignore /*sink*/ = MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}); ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ElementsAre(source1, source2)); - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); } TEST(ExecPlanConstruction, AutoLabel) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "", /*inputs=*/{}, /*num_outputs=*/2); - auto source2 = - MakeDummyNode(plan.get(), "some_label", /*inputs=*/{}, /*num_outputs=*/1); - auto source3 = MakeDummyNode(plan.get(), "", /*inputs=*/{}, /*num_outputs=*/2); + auto source1 = MakeDummyNode(plan.get(), "", /*inputs=*/{}); + auto source2 = MakeDummyNode(plan.get(), "some_label", /*inputs=*/{}); + auto source3 = MakeDummyNode(plan.get(), "", /*inputs=*/{}); ASSERT_EQ("0", source1->label()); ASSERT_EQ("some_label", source2->label()); ASSERT_EQ("2", source3->label()); } -struct StartStopTracker { - std::vector started, stopped; +struct StartAbortTracker { + std::vector started, aborted; StartProducingFunc start_producing_func(Status st = Status::OK()) { return [this, st](ExecNode* node) { @@ -123,50 +105,41 @@ struct StartStopTracker { }; } - StopProducingFunc stop_producing_func() { - return [this](ExecNode* node) { stopped.push_back(node->label()); }; + AbortFunc abort_func() { + return [this](ExecNode* node) { aborted.push_back(node->label()); }; } }; TEST(ExecPlan, DummyStartProducing) { - StartStopTracker t; + StartAbortTracker t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2, - t.start_producing_func(), t.stop_producing_func()); + auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, + t.start_producing_func(), t.abort_func()); - auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); + auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, + t.start_producing_func(), t.abort_func()); - auto process1 = - MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2, - t.start_producing_func(), t.stop_producing_func()); + auto process1 = MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, + t.start_producing_func(), t.abort_func()); - auto process2 = - MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + auto process2 = MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2}, + t.start_producing_func(), t.abort_func()); auto process3 = MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, source1, process2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + t.start_producing_func(), t.abort_func()); - MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0, - t.start_producing_func(), t.stop_producing_func()); + MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, t.start_producing_func(), + t.abort_func()); ASSERT_OK(plan->Validate()); ASSERT_EQ(t.started.size(), 0); - ASSERT_EQ(t.stopped.size(), 0); + ASSERT_EQ(t.aborted.size(), 0); ASSERT_OK(plan->StartProducing()); - // Note that any correct reverse topological order may do - ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1", - "source2", "source1")); - - plan->StopProducing(); - ASSERT_THAT(plan->finished(), Finishes(Ok())); - // Note that any correct topological order may do - ASSERT_THAT(t.stopped, ElementsAre("source1", "source2", "process1", "process2", + ASSERT_THAT(t.started, ElementsAre("source1", "source2", "process1", "process2", "process3", "sink")); ASSERT_THAT(plan->StartProducing(), @@ -174,41 +147,40 @@ TEST(ExecPlan, DummyStartProducing) { } TEST(ExecPlan, DummyStartProducingError) { - StartStopTracker t; + StartAbortTracker t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode( - plan.get(), "source1", /*num_inputs=*/{}, /*num_outputs=*/2, - t.start_producing_func(Status::NotImplemented("zzz")), t.stop_producing_func()); + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/{}, + t.start_producing_func(Status::NotImplemented("zzz")), + t.abort_func()); - auto source2 = - MakeDummyNode(plan.get(), "source2", /*num_inputs=*/{}, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/{}, + t.start_producing_func(), t.abort_func()); - auto process1 = MakeDummyNode( - plan.get(), "process1", /*num_inputs=*/{source1}, /*num_outputs=*/2, - t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func()); + auto process1 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/{source1}, + t.start_producing_func(Status::IOError("xxx")), t.abort_func()); auto process2 = MakeDummyNode(plan.get(), "process2", /*num_inputs=*/{process1, source2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + t.start_producing_func(), t.abort_func()); auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/{process1, source1, process2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + t.start_producing_func(), t.abort_func()); - MakeDummyNode(plan.get(), "sink", /*num_inputs=*/{process3}, /*num_outputs=*/0, - t.start_producing_func(), t.stop_producing_func()); + MakeDummyNode(plan.get(), "sink", /*num_inputs=*/{process3}, t.start_producing_func(), + t.abort_func()); ASSERT_OK(plan->Validate()); ASSERT_EQ(t.started.size(), 0); - ASSERT_EQ(t.stopped.size(), 0); + ASSERT_EQ(t.aborted.size(), 0); - // `process1` raises IOError - ASSERT_THAT(plan->StartProducing(), Raises(StatusCode::IOError)); - ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1")); - // Nodes that started successfully were stopped in reverse order - ASSERT_THAT(t.stopped, ElementsAre("process2", "process3", "sink")); + // `source1` raises NotImplemented + ASSERT_THAT(plan->StartProducing(), Raises(StatusCode::NotImplemented)); + ASSERT_THAT(t.started, ElementsAre("source1")); + ASSERT_THAT(t.aborted, ElementsAre("source1", "source2", "process1", "process2", + "process3", "sink")); } TEST(ExecPlanExecution, SourceSink) { @@ -305,7 +277,8 @@ TEST(ExecPlanExecution, SinkNodeBackpressure) { kPauseIfAbove * static_cast(batch->TotalBufferSize()); uint32_t resume_if_below_bytes = kResumeIfBelow * static_cast(batch->TotalBufferSize()); - EXPECT_OK_AND_ASSIGN(std::shared_ptr plan, ExecPlan::Make()); + ExecContext ctx(default_memory_pool(), arrow::internal::GetCpuThreadPool()); + EXPECT_OK_AND_ASSIGN(std::shared_ptr plan, ExecPlan::Make(&ctx)); PushGenerator> batch_producer; AsyncGenerator> sink_gen; BackpressureMonitor* backpressure_monitor; @@ -349,7 +322,6 @@ TEST(ExecPlanExecution, SinkNodeBackpressure) { // Cleanup batch_producer.producer().Push(IterationEnd>()); - plan->StopProducing(); ASSERT_FINISHES_OK(plan->finished()); } @@ -366,8 +338,6 @@ TEST(ExecPlan, ToString) { {"sink", SinkNodeOptions{&sink_gen}}, }) .AddToPlan(plan.get())); - EXPECT_EQ(plan->sources()[0]->ToString(), R"(:SourceNode{})"); - EXPECT_EQ(plan->sinks()[0]->ToString(), R"(:SinkNode{})"); EXPECT_EQ(plan->ToString(), R"(ExecPlan with 2 nodes: :SinkNode{} :SourceNode{} @@ -445,8 +415,8 @@ custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, count(i32, {mode=NON_NULL}), ]} :UnionNode{} - rhs:SourceNode{} lhs:SourceNode{} + rhs:SourceNode{} )a"); } @@ -511,9 +481,21 @@ TEST(ExecPlanExecution, SourceConsumingSink) { for (bool parallel : {false, true}) { SCOPED_TRACE(parallel ? "parallel" : "single threaded"); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + ExecContext ctx(default_memory_pool(), arrow::internal::GetCpuThreadPool()); + ASSERT_OK_AND_ASSIGN(auto plan, + ExecPlan::Make(parallel ? &ctx : default_exec_context())); std::atomic batches_seen{0}; Future<> finish = Future<>::Make(); + + struct AutoMarker { + explicit AutoMarker(Future<> fut) : fut_(std::move(fut)) {} + ~AutoMarker() { + if (!fut_.is_finished()) fut_.MarkFinished(); + } + Future<> fut_; + } marker(finish); + struct TestConsumer : public SinkNodeConsumer { TestConsumer(std::atomic* batches_seen, Future<> finish) : batches_seen(batches_seen), finish(std::move(finish)) {} @@ -544,12 +526,16 @@ TEST(ExecPlanExecution, SourceConsumingSink) { ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source}, ConsumingSinkNodeOptions(consumer))); ASSERT_OK(plan->StartProducing()); - // Source should finish fairly quickly - ASSERT_FINISHES_OK(source->finished()); + SleepABit(); + SleepABit(); SleepABit(); ASSERT_EQ(2, batches_seen); - // Consumer isn't finished and so plan shouldn't have finished - AssertNotFinished(plan->finished()); + if (parallel) { + // Consumer isn't finished and so plan shouldn't have finished, + // unless we're in synchronous mode in which case the plan must've + // finished by now. + AssertNotFinished(plan->finished()); + } // Mark consumption complete, plan should finish finish.MarkFinished(); ASSERT_FINISHES_OK(plan->finished()); @@ -577,14 +563,13 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) { basic_data.gen(parallel, slow)))); ASSERT_OK(MakeExecNode("table_sink", plan.get(), {source}, options)); ASSERT_OK(plan->StartProducing()); - // Source should finish fairly quickly - ASSERT_FINISHES_OK(source->finished()); SleepABit(); + ASSERT_FINISHES_OK(plan->finished()); + ASSERT_OK_AND_ASSIGN(auto actual, TableFromExecBatches(basic_data.schema, basic_data.batches)); ASSERT_EQ(5, out->num_rows()); AssertTablesEqual(*actual, *out); - ASSERT_FINISHES_OK(plan->finished()); } } } @@ -605,21 +590,21 @@ TEST(ExecPlanExecution, ConsumingSinkNames) { auto consumer = std::make_shared(); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); auto basic_data = MakeBasicBatches(); - ASSERT_OK(Declaration::Sequence( - {{"source", - SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))}, - {"consuming_sink", ConsumingSinkNodeOptions(consumer, names)}}) - .AddToPlan(plan.get())); - ASSERT_OK_AND_ASSIGN( - auto source, - MakeExecNode("source", plan.get(), {}, - SourceNodeOptions(basic_data.schema, basic_data.gen(false, false)))); - ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source}, - ConsumingSinkNodeOptions(consumer, names))); + if (names.size() != 0 && names.size() != static_cast(basic_data.batches[0].num_values())) { - ASSERT_RAISES(Invalid, plan->StartProducing()); + ASSERT_RAISES(Invalid, + Declaration::Sequence( + {{"source", SourceNodeOptions(basic_data.schema, + basic_data.gen(false, false))}, + {"consuming_sink", ConsumingSinkNodeOptions(consumer, names)}}) + .AddToPlan(plan.get())); } else { + ASSERT_OK(Declaration::Sequence( + {{"source", + SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))}, + {"consuming_sink", ConsumingSinkNodeOptions(consumer, names)}}) + .AddToPlan(plan.get())); auto expected_names = names.size() == 0 ? basic_data.schema->field_names() : names; ASSERT_OK(plan->StartProducing()); ASSERT_FINISHES_OK(plan->finished()); @@ -635,7 +620,8 @@ TEST(ExecPlanExecution, ConsumingSinkError) { return Status::Invalid("XYZ"); } Status Consume(ExecBatch batch) override { return Status::OK(); } - Future<> Finish() override { return Future<>::MakeFinished(); } + Future<> Finish() override { return finished_; } + Future<> finished_ = Future<>::MakeFinished(); }; struct ConsumeErrorConsumer : public SinkNodeConsumer { Status Init(const std::shared_ptr& schema, @@ -643,7 +629,8 @@ TEST(ExecPlanExecution, ConsumingSinkError) { return Status::OK(); } Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); } - Future<> Finish() override { return Future<>::MakeFinished(); } + Future<> Finish() override { return finished_; } + Future<> finished_ = Future<>::MakeFinished(); }; struct FinishErrorConsumer : public SinkNodeConsumer { Status Init(const std::shared_ptr& schema, @@ -651,14 +638,17 @@ TEST(ExecPlanExecution, ConsumingSinkError) { return Status::OK(); } Status Consume(ExecBatch batch) override { return Status::OK(); } - Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); } + Future<> Finish() override { return finished_; } + Future<> finished_ = Future<>::MakeFinished(Status::Invalid("XYZ")); }; std::vector> consumers{ std::make_shared(), std::make_shared(), std::make_shared()}; + ExecContext ctx(default_memory_pool(), arrow::internal::GetCpuThreadPool()); + for (auto& consumer : consumers) { - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&ctx)); auto basic_data = MakeBasicBatches(); ASSERT_OK(Declaration::Sequence( {{"source", @@ -778,7 +768,7 @@ TEST(ExecPlanExecution, StressSourceGroupedSumStop) { ASSERT_OK(plan->Validate()); ASSERT_OK(plan->StartProducing()); - plan->StopProducing(); + plan->Abort(); ASSERT_FINISHES_OK(plan->finished()); } } @@ -812,7 +802,7 @@ TEST(ExecPlanExecution, StressSourceSinkStopped) { EXPECT_THAT(sink_gen(), Finishes(ResultWith(Optional(random_data.batches[0])))); - plan->StopProducing(); + plan->Abort(); ASSERT_THAT(plan->finished(), Finishes(Ok())); } } diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc index 5ce5428a15d..a27a1e1ae39 100644 --- a/cpp/src/arrow/compute/exec/project_node.cc +++ b/cpp/src/arrow/compute/exec/project_node.cc @@ -21,7 +21,6 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" -#include "arrow/compute/exec/map_node.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/util.h" #include "arrow/datum.h" @@ -38,11 +37,12 @@ using internal::checked_cast; namespace compute { namespace { -class ProjectNode : public MapNode { +class ProjectNode : public ExecNode { public: ProjectNode(ExecPlan* plan, std::vector inputs, std::shared_ptr output_schema, std::vector exprs) - : MapNode(plan, std::move(inputs), std::move(output_schema)), + : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"}, + std::move(output_schema)), exprs_(std::move(exprs)) {} static Result Make(ExecPlan* plan, std::vector inputs, @@ -76,7 +76,12 @@ class ProjectNode : public MapNode { const char* kind_name() const override { return "ProjectNode"; } - Result DoProject(const ExecBatch& target) { + Result DoProject(ExecBatch target) { + util::tracing::Span span; + START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived", + {{"project", ToStringExtra()}, + {"node.label", label()}, + {"batch.length", target.length}}); std::vector values{exprs_.size()}; for (size_t i = 0; i < exprs_.size(); ++i) { util::tracing::Span span; @@ -90,26 +95,34 @@ class ProjectNode : public MapNode { ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target, plan()->exec_context())); } + END_SPAN(span); + return ExecBatch{std::move(values), target.length}; } - void InputReceived(ExecNode* input, ExecBatch batch) override { - EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); + Status StartProducing() override { return Status::OK(); } + + void PauseProducing(ExecNode* output, int32_t counter) override { + inputs_[0]->PauseProducing(this, counter); + } + + void ResumeProducing(ExecNode* output, int32_t counter) override { + inputs_[0]->ResumeProducing(this, counter); + } + + Status InputReceived(ExecNode* input, ExecBatch batch) override { DCHECK_EQ(input, inputs_[0]); - auto func = [this](ExecBatch batch) { - util::tracing::Span span; - START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived", - {{"project", ToStringExtra()}, - {"node.label", label()}, - {"batch.length", batch.length}}); - auto result = DoProject(std::move(batch)); - MARK_SPAN(span, result.status()); - END_SPAN(span); - return result; - }; - this->SubmitTask(std::move(func), std::move(batch)); + ARROW_ASSIGN_OR_RAISE(ExecBatch projected, DoProject(std::move(batch))); + return output_->InputReceived(this, std::move(projected)); } + Status InputFinished(ExecNode* input, int num_batches) override { + END_SPAN(span_); + return output_->InputFinished(this, num_batches); + } + + void Abort() override {} + protected: std::string ToStringExtra(int indent = 0) const override { std::stringstream ss; diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 96a34bff437..9318645c861 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -92,8 +92,7 @@ class SinkNode : public ExecNode { AsyncGenerator>* generator, BackpressureOptions backpressure, BackpressureMonitor** backpressure_monitor_out) - : ExecNode(plan, std::move(inputs), {"collected"}, {}, - /*num_outputs=*/0), + : ExecNode(plan, std::move(inputs), {"collected"}, {}), backpressure_queue_(backpressure.resume_if_below, backpressure.pause_if_above), push_gen_(), producer_(push_gen_.producer()), @@ -131,14 +130,7 @@ class SinkNode : public ExecNode { const char* kind_name() const override { return "SinkNode"; } - Status StartProducing() override { - START_COMPUTE_SPAN(span_, std::string(kind_name()) + ":" + label(), - {{"node.label", label()}, - {"node.detail", ToString()}, - {"node.kind", kind_name()}}); - END_SPAN_ON_FUTURE_COMPLETION(span_, finished_); - return Status::OK(); - } + Status StartProducing() override { return Status::OK(); } // sink nodes have no outputs from which to feel backpressure [[noreturn]] static void NoOutputs() { @@ -150,14 +142,6 @@ class SinkNode : public ExecNode { [[noreturn]] void PauseProducing(ExecNode* output, int32_t counter) override { NoOutputs(); } - [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); } - - void StopProducing() override { - EVENT(span_, "StopProducing"); - - Finish(); - inputs_[0]->StopProducing(this); - } void RecordBackpressureBytesUsed(const ExecBatch& batch) { if (backpressure_queue_.enabled()) { @@ -181,7 +165,7 @@ class SinkNode : public ExecNode { } } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); util::tracing::Span span; START_COMPUTE_SPAN_WITH_PARENT( @@ -192,37 +176,28 @@ class SinkNode : public ExecNode { RecordBackpressureBytesUsed(batch); bool did_push = producer_.Push(std::move(batch)); - if (!did_push) return; // producer_ was Closed already + if (!did_push) return Status::OK(); // producer_ was Closed already if (input_counter_.Increment()) { Finish(); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - EVENT(span_, "ErrorReceived", {{"error", error.message()}}); - DCHECK_EQ(input, inputs_[0]); - - producer_.Push(std::move(error)); - - if (input_counter_.Cancel()) { - Finish(); - } - inputs_[0]->StopProducing(this); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { EVENT(span_, "InputFinished", {{"batches.length", total_batches}}); if (input_counter_.SetTotal(total_batches)) { Finish(); } + return Status::OK(); } + void Abort() override { Finish(); } + protected: virtual void Finish() { - if (producer_.Close()) { - finished_.MarkFinished(); - } + END_SPAN(span_); + producer_.Close(); } static Status ValidateOptions(const SinkNodeOptions& sink_options) { @@ -260,11 +235,9 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl { public: ConsumingSinkNode(ExecPlan* plan, std::vector inputs, std::shared_ptr consumer, - std::vector names) - : ExecNode(plan, std::move(inputs), {"to_consume"}, {}, - /*num_outputs=*/0), - consumer_(std::move(consumer)), - names_(std::move(names)) {} + std::shared_ptr output_schema) + : ExecNode(plan, std::move(inputs), {"to_consume"}, std::move(output_schema)), + consumer_(std::move(consumer)) {} static Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { @@ -275,38 +248,37 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl { return Status::Invalid("A SinkNodeConsumer is required"); } - return plan->EmplaceNode(plan, std::move(inputs), - std::move(sink_options.consumer), - std::move(sink_options.names)); - } - - const char* kind_name() const override { return "ConsumingSinkNode"; } - - Status StartProducing() override { - START_COMPUTE_SPAN(span_, std::string(kind_name()) + ":" + label(), - {{"node.label", label()}, - {"node.detail", ToString()}, - {"node.kind", kind_name()}}); - END_SPAN_ON_FUTURE_COMPLETION(span_, finished_); - DCHECK_GT(inputs_.size(), 0); - auto output_schema = inputs_[0]->output_schema(); - if (names_.size() > 0) { + DCHECK_GT(inputs.size(), 0); + auto output_schema = inputs[0]->output_schema(); + if (sink_options.names.size() > 0) { int num_fields = output_schema->num_fields(); - if (names_.size() != static_cast(num_fields)) { + if (sink_options.names.size() != static_cast(num_fields)) { return Status::Invalid("ConsumingSinkNode with mismatched number of names"); } FieldVector fields(num_fields); int i = 0; for (const auto& output_field : output_schema->fields()) { - fields[i] = field(names_[i], output_field->type()); + fields[i] = field(sink_options.names[i], output_field->type()); ++i; } output_schema = schema(std::move(fields)); } - RETURN_NOT_OK(consumer_->Init(output_schema, this)); + return plan->EmplaceNode(plan, std::move(inputs), + std::move(sink_options.consumer), + std::move(output_schema)); + } + + const char* kind_name() const override { return "ConsumingSinkNode"; } + + Status Init() override { + RETURN_NOT_OK(ExecNode::Init()); + RETURN_NOT_OK(consumer_->Init(output_schema_, this)); + ARROW_ASSIGN_OR_RAISE(consumer_fut_, plan_->BeginExternalTask()); return Status::OK(); } + Status StartProducing() override { return Status::OK(); } + // sink nodes have no outputs from which to feel backpressure [[noreturn]] static void NoOutputs() { Unreachable("no outputs; this should never be called"); @@ -317,19 +289,14 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl { [[noreturn]] void PauseProducing(ExecNode* output, int32_t counter) override { NoOutputs(); } - [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); } void Pause() override { inputs_[0]->PauseProducing(this, ++backpressure_counter_); } void Resume() override { inputs_[0]->ResumeProducing(this, ++backpressure_counter_); } - void StopProducing() override { - EVENT(span_, "StopProducing"); - Finish(Status::OK()); - inputs_[0]->StopProducing(this); - } + void Abort() override { Finish(); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); util::tracing::Span span; START_COMPUTE_SPAN_WITH_PARENT( @@ -338,55 +305,41 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl { DCHECK_EQ(input, inputs_[0]); - // This can happen if an error was received and the source hasn't yet stopped. Since - // we have already called consumer_->Finish we don't want to call consumer_->Consume - if (input_counter_.Completed()) { - return; - } - - Status consumption_status = consumer_->Consume(std::move(batch)); - if (!consumption_status.ok()) { - if (input_counter_.Cancel()) { - Finish(std::move(consumption_status)); - } - inputs_[0]->StopProducing(this); - return; + Status st = consumer_->Consume(std::move(batch)); + if (!st.ok()) { + Finish(); + return st; } if (input_counter_.Increment()) { - Finish(Status::OK()); + Finish(); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - EVENT(span_, "ErrorReceived", {{"error", error.message()}}); - DCHECK_EQ(input, inputs_[0]); - - if (input_counter_.Cancel()) Finish(error); - - inputs_[0]->StopProducing(this); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { EVENT(span_, "InputFinished", {{"batches.length", total_batches}}); if (input_counter_.SetTotal(total_batches)) { - Finish(Status::OK()); + Finish(); } + return Status::OK(); } protected: - void Finish(const Status& finish_st) { - consumer_->Finish().AddCallback([this, finish_st](const Status& st) { - // Prefer the plan error over the consumer error - Status final_status = finish_st & st; - finished_.MarkFinished(std::move(final_status)); - }); + void Finish() { + bool expected = false; + if (finished_.compare_exchange_strong(expected, true)) { + END_SPAN(span_); + consumer_->Finish().AddCallback( + [this](const Status& st) { consumer_fut_.MarkFinished(st); }); + } } AtomicCounter input_counter_; std::shared_ptr consumer_; - std::vector names_; int32_t backpressure_counter_ = 0; + std::atomic finished_{false}; + Future<> consumer_fut_; }; static Result MakeTableConsumingSinkNode( compute::ExecPlan* plan, std::vector inputs, @@ -468,7 +421,7 @@ struct OrderBySinkNode final : public SinkNode { return ValidateCommonOrderOptions(options); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); util::tracing::Span span; START_COMPUTE_SPAN_WITH_PARENT( @@ -477,21 +430,14 @@ struct OrderBySinkNode final : public SinkNode { DCHECK_EQ(input, inputs_[0]); - auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(), - plan()->exec_context()->memory_pool()); - if (ErrorIfNotOk(maybe_batch.status())) { - StopProducing(); - if (input_counter_.Cancel()) { - finished_.MarkFinished(maybe_batch.status()); - } - return; - } - auto record_batch = maybe_batch.MoveValueUnsafe(); - + ARROW_ASSIGN_OR_RAISE(auto record_batch, + batch.ToRecordBatch(inputs_[0]->output_schema(), + plan()->exec_context()->memory_pool())); impl_->InputReceived(std::move(record_batch)); if (input_counter_.Increment()) { Finish(); } + return Status::OK(); } protected: @@ -512,7 +458,7 @@ struct OrderBySinkNode final : public SinkNode { util::tracing::Span span; START_COMPUTE_SPAN_WITH_PARENT(span, span_, "Finish", {{"node.label", label()}}); Status st = DoFinish(); - if (ErrorIfNotOk(st)) { + if (!st.ok()) { producer_.Push(std::move(st)); } SinkNode::Finish(); diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 1d51a5c1d28..b5c36048a6c 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -48,8 +48,7 @@ namespace { struct SourceNode : ExecNode { SourceNode(ExecPlan* plan, std::shared_ptr output_schema, AsyncGenerator> generator) - : ExecNode(plan, {}, {}, std::move(output_schema), - /*num_outputs=*/1), + : ExecNode(plan, {}, {}, std::move(output_schema)), generator_(std::move(generator)) {} static Result Make(ExecPlan* plan, std::vector inputs, @@ -65,17 +64,54 @@ struct SourceNode : ExecNode { [[noreturn]] static void NoInputs() { Unreachable("no inputs; this should never be called"); } - [[noreturn]] void InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } - [[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); } - [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); } + [[noreturn]] Status InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } + [[noreturn]] Status InputFinished(ExecNode*, int) override { NoInputs(); } - Status StartProducing() override { + Status Init() override { START_COMPUTE_SPAN(span_, std::string(kind_name()) + ":" + label(), {{"node.kind", kind_name()}, {"node.label", label()}, {"node.output_schema", output_schema()->ToString()}, {"node.detail", ToString()}}); - END_SPAN_ON_FUTURE_COMPLETION(span_, finished_); + return Status::OK(); + } + + Status SynchronousSourceLoop() { + for (;;) { + Future> fut = generator_(); + ARROW_ASSIGN_OR_RAISE(std::optional maybe_morsel, fut.result()); + if (!maybe_morsel.has_value()) break; + + bool use_legacy_batching = plan_->UseLegacyBatching(); + ExecBatch morsel = std::move(*maybe_morsel); + int64_t morsel_length = static_cast(morsel.length); + if (use_legacy_batching || morsel_length == 0) { + // For various reasons (e.g. ARROW-13982) we pass empty batches + // through + batch_count_++; + } else { + int num_batches = + static_cast(bit_util::CeilDiv(morsel_length, ExecPlan::kMaxBatchSize)); + batch_count_ += num_batches; + } + int64_t offset = 0; + do { + int64_t batch_size = + std::min(morsel_length - offset, ExecPlan::kMaxBatchSize); + // In order for the legacy batching model to work we must + // not slice batches from the source + if (use_legacy_batching) { + batch_size = morsel_length; + } + ExecBatch batch = morsel.Slice(offset, batch_size); + offset += batch_size; + RETURN_NOT_OK(output_->InputReceived(this, std::move(batch))); + } while (offset < morsel.length); + } + return output_->InputFinished(this, batch_count_); + } + + Status StartProducing() override { { // If another exec node encountered an error during its StartProducing call // it might have already called StopProducing on all of its inputs (including this @@ -90,6 +126,10 @@ struct SourceNode : ExecNode { CallbackOptions options; auto executor = plan()->exec_context()->executor(); + if (!executor) { + return SynchronousSourceLoop(); + } + if (executor) { // These options will transfer execution to the desired Executor if necessary. // This can happen for in-memory scans where batches didn't require @@ -99,11 +139,12 @@ struct SourceNode : ExecNode { options.should_schedule = ShouldSchedule::IfDifferentExecutor; } ARROW_ASSIGN_OR_RAISE(Future<> scan_task, plan_->BeginExternalTask()); + END_SPAN_ON_FUTURE_COMPLETION(span_, scan_task); if (!scan_task.is_valid()) { - finished_.MarkFinished(); // Plan has already been aborted, no need to start scanning return Status::OK(); } + auto fut = Loop([this, options] { std::unique_lock lock(mutex_); if (stop_requested_) { @@ -143,7 +184,7 @@ struct SourceNode : ExecNode { } ExecBatch batch = morsel.Slice(offset, batch_size); offset += batch_size; - outputs_[0]->InputReceived(this, std::move(batch)); + RETURN_NOT_OK(output_->InputReceived(this, std::move(batch))); } while (offset < morsel.length); return Status::OK(); })); @@ -156,19 +197,18 @@ struct SourceNode : ExecNode { return Future>::MakeFinished(Continue()); }, [=](const Status& error) -> ControlFlow { - outputs_[0]->ErrorReceived(this, error); + status_ = error; return Break(batch_count_); }, options); }) .Then( [this, scan_task](int total_batches) mutable { - outputs_[0]->InputFinished(this, total_batches); - scan_task.MarkFinished(); - finished_.MarkFinished(); + RETURN_NOT_OK(output_->InputFinished(this, total_batches)); + scan_task.MarkFinished(status_); + return Status::OK(); }, {}, options); - if (!executor && finished_.is_finished()) return finished_.status(); return Status::OK(); } @@ -201,18 +241,7 @@ struct SourceNode : ExecNode { to_finish.MarkFinished(); } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override { - std::unique_lock lock(mutex_); - stop_requested_ = true; - if (!started_) { - finished_.MarkFinished(); - } - } + void Abort() override { stop_requested_ = true; } private: std::mutex mutex_; @@ -222,6 +251,7 @@ struct SourceNode : ExecNode { bool started_ = false; int batch_count_{0}; AsyncGenerator> generator_; + Status status_ = Status::OK(); }; struct TableSourceNode : public SourceNode { diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc b/cpp/src/arrow/compute/exec/swiss_join.cc index 5b01edb1198..2d68314db62 100644 --- a/cpp/src/arrow/compute/exec/swiss_join.cc +++ b/cpp/src/arrow/compute/exec/swiss_join.cc @@ -1934,7 +1934,9 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), - [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + [&](ExecBatch batch) { + return output_batch_fn_(thread_id, std::move(batch)); + })); } } else { // We need to output matching pairs of rows from both sides of the join. @@ -1972,8 +1974,9 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, // RETURN_NOT_OK(materialize_[thread_id]->Append( keypayload_batch, num_matches_next, materialize_batch_ids, - materialize_key_ids, materialize_payload_ids, - [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + materialize_key_ids, materialize_payload_ids, [&](ExecBatch batch) { + return output_batch_fn_(thread_id, std::move(batch)); + })); } // For left-outer and full-outer joins output non-matches. @@ -1997,7 +2000,9 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), - [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + [&](ExecBatch batch) { + return output_batch_fn_(thread_id, std::move(batch)); + })); } } @@ -2014,7 +2019,7 @@ Status JoinProbeProcessor::OnFinished() { for (size_t i = 0; i < materialize_.size(); ++i) { JoinResultMaterialize& materialize = *materialize_[i]; RETURN_NOT_OK(materialize.Flush( - [&](ExecBatch batch) { output_batch_fn_(i, std::move(batch)); })); + [&](ExecBatch batch) { return output_batch_fn_(i, std::move(batch)); })); } return Status::OK(); @@ -2363,7 +2368,8 @@ class SwissJoin : public HashJoinImpl { Status status = local_states_[thread_id].materialize.AppendBuildOnly( num_output_rows, key_ids_buf.mutable_data(), payload_ids_buf.mutable_data(), [&](ExecBatch batch) { - output_batch_callback_(static_cast(thread_id), std::move(batch)); + return output_batch_callback_(static_cast(thread_id), + std::move(batch)); }); RETURN_NOT_OK(CancelIfNotOK(status)); if (!status.ok()) { @@ -2401,9 +2407,7 @@ class SwissJoin : public HashJoinImpl { num_produced_batches += materialize.num_produced_batches(); } - finished_callback_(num_produced_batches); - - return Status::OK(); + return finished_callback_(num_produced_batches); } Result KeyPayloadFromInput(int side, ExecBatch* input) { diff --git a/cpp/src/arrow/compute/exec/swiss_join.h b/cpp/src/arrow/compute/exec/swiss_join.h index bf3273c4e04..a6ba8c4603c 100644 --- a/cpp/src/arrow/compute/exec/swiss_join.h +++ b/cpp/src/arrow/compute/exec/swiss_join.h @@ -558,7 +558,7 @@ class JoinResultMaterialize { if (num_rows_appended < num_rows_to_append) { ExecBatch batch; ARROW_RETURN_NOT_OK(Flush(&batch)); - output_batch_fn(batch); + ARROW_RETURN_NOT_OK(output_batch_fn(batch)); num_rows_to_append -= num_rows_appended; offset += num_rows_appended; } else { @@ -613,7 +613,7 @@ class JoinResultMaterialize { if (num_rows_ > 0) { ExecBatch batch({}, num_rows_); ARROW_RETURN_NOT_OK(Flush(&batch)); - output_batch_fn(std::move(batch)); + return output_batch_fn(std::move(batch)); } return Status::OK(); } @@ -731,7 +731,7 @@ class JoinMatchIterator { // class JoinProbeProcessor { public: - using OutputBatchFn = std::function; + using OutputBatchFn = std::function; void Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table, std::vector materialize, diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index efb91a708ab..a2e379cd601 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -58,25 +58,24 @@ namespace compute { namespace { struct DummyNode : ExecNode { - DummyNode(ExecPlan* plan, NodeVector inputs, int num_outputs, - StartProducingFunc start_producing, StopProducingFunc stop_producing) - : ExecNode(plan, std::move(inputs), {}, dummy_schema(), num_outputs), + DummyNode(ExecPlan* plan, NodeVector inputs, StartProducingFunc start_producing, + AbortFunc abort) + : ExecNode(plan, std::move(inputs), {}, dummy_schema()), start_producing_(std::move(start_producing)), - stop_producing_(std::move(stop_producing)) { + abort_(std::move(abort)) { input_labels_.resize(inputs_.size()); for (size_t i = 0; i < input_labels_.size(); ++i) { input_labels_[i] = std::to_string(i); } - finished_.MarkFinished(); } const char* kind_name() const override { return "Dummy"; } - void InputReceived(ExecNode* input, ExecBatch batch) override {} + Status InputReceived(ExecNode* input, ExecBatch batch) override { return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override {} - - void InputFinished(ExecNode* input, int total_batches) override {} + Status InputFinished(ExecNode* input, int total_batches) override { + return Status::OK(); + } Status StartProducing() override { if (start_producing_) { @@ -87,43 +86,28 @@ struct DummyNode : ExecNode { } void PauseProducing(ExecNode* output, int32_t counter) override { - ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + ASSERT_NE(output_, nullptr) << "Sink nodes should not experience backpressure"; AssertIsOutput(output); } void ResumeProducing(ExecNode* output, int32_t counter) override { - ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; - AssertIsOutput(output); - } - - void StopProducing(ExecNode* output) override { - EXPECT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + ASSERT_NE(output_, nullptr) << "Sink nodes should not experience backpressure"; AssertIsOutput(output); } - void StopProducing() override { - if (started_) { - for (const auto& input : inputs_) { - input->StopProducing(this); - } - if (stop_producing_) { - stop_producing_(this); - } - } + void Abort() override { + if (abort_) abort_(this); } private: - void AssertIsOutput(ExecNode* output) { - auto it = std::find(outputs_.begin(), outputs_.end(), output); - ASSERT_NE(it, outputs_.end()); - } + void AssertIsOutput(ExecNode* output) { ASSERT_EQ(output, output_); } std::shared_ptr dummy_schema() const { return schema({field("dummy", null())}); } StartProducingFunc start_producing_; - StopProducingFunc stop_producing_; + AbortFunc abort_; std::unordered_set requested_stop_; bool started_ = false; }; @@ -131,11 +115,9 @@ struct DummyNode : ExecNode { } // namespace ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, - int num_outputs, StartProducingFunc start_producing, - StopProducingFunc stop_producing) { - auto node = - plan->EmplaceNode(plan, std::move(inputs), num_outputs, - std::move(start_producing), std::move(stop_producing)); + StartProducingFunc start_producing, AbortFunc abort) { + auto node = plan->EmplaceNode(plan, std::move(inputs), + std::move(start_producing), std::move(abort)); if (!label.empty()) { node->SetLabel(std::move(label)); } diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index ae7eac61e95..5efedcabe27 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -37,12 +37,12 @@ namespace arrow { namespace compute { using StartProducingFunc = std::function; -using StopProducingFunc = std::function; +using AbortFunc = std::function; // Make a dummy node that has no execution behaviour ARROW_TESTING_EXPORT ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, - int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); + StartProducingFunc = {}, AbortFunc = {}); ARROW_TESTING_EXPORT ExecBatch ExecBatchFromJSON(const std::vector& types, std::string_view json); diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc b/cpp/src/arrow/compute/exec/tpch_node.cc index 30dbd511e63..8dfc087d25f 100644 --- a/cpp/src/arrow/compute/exec/tpch_node.cc +++ b/cpp/src/arrow/compute/exec/tpch_node.cc @@ -498,8 +498,8 @@ bool TpchPseudotext::GenerateSentence(int64_t& offset, random::pcg32_fast& rng, class TpchTableGenerator { public: - using OutputBatchCallback = std::function; - using FinishedCallback = std::function; + using OutputBatchCallback = std::function; + using FinishedCallback = std::function; using GenerateFn = std::function; using ScheduleCallback = std::function; using AbortCallback = std::function; @@ -2513,11 +2513,11 @@ class SupplierGenerator : public TpchTableGenerator { int64_t batches_to_generate = (rows_to_generate_ + batch_size_ - 1) / batch_size_; int64_t batches_outputted_before_this_one = batches_outputted_.fetch_add(1); bool is_last_batch = batches_outputted_before_this_one == (batches_to_generate - 1); - output_callback_(std::move(eb)); + RETURN_NOT_OK(output_callback_(std::move(eb))); if (is_last_batch) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + return finished_callback_(batches_outputted_.load()); return Status::OK(); } return schedule_callback_( @@ -2721,14 +2721,14 @@ class PartGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + return finished_callback_(batches_outputted_.load()); return Status::OK(); } return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -2783,14 +2783,14 @@ class PartSuppGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + return finished_callback_(batches_outputted_.load()); return Status::OK(); } return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -2897,11 +2897,11 @@ class CustomerGenerator : public TpchTableGenerator { int64_t batches_to_generate = (rows_to_generate_ + batch_size_ - 1) / batch_size_; int64_t batches_generated_before_this_one = batches_outputted_.fetch_add(1); bool is_last_batch = batches_generated_before_this_one == (batches_to_generate - 1); - output_callback_(std::move(eb)); + RETURN_NOT_OK(output_callback_(std::move(eb))); if (is_last_batch) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + return finished_callback_(batches_outputted_.load()); return Status::OK(); } return schedule_callback_( @@ -3102,14 +3102,14 @@ class OrdersGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + return finished_callback_(batches_outputted_.load()); return Status::OK(); } return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -3164,7 +3164,7 @@ class LineitemGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + return finished_callback_(batches_outputted_.load()); return Status::OK(); } // We may have generated but not outputted all of the batches. @@ -3172,7 +3172,7 @@ class LineitemGenerator : public TpchTableGenerator { [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -3227,8 +3227,8 @@ class NationGenerator : public TpchTableGenerator { std::vector result; for (const int& col : column_indices_) result.push_back(fields[col]); ARROW_ASSIGN_OR_RAISE(ExecBatch batch, ExecBatch::Make(std::move(result))); - output_callback(std::move(batch)); - finished_callback(static_cast(1)); + RETURN_NOT_OK(output_callback(std::move(batch))); + RETURN_NOT_OK(finished_callback(static_cast(1))); return Status::OK(); } @@ -3316,8 +3316,8 @@ class RegionGenerator : public TpchTableGenerator { std::vector result; for (const int& col : column_indices_) result.push_back(fields[col]); ARROW_ASSIGN_OR_RAISE(ExecBatch batch, ExecBatch::Make(std::move(result))); - output_callback(std::move(batch)); - finished_callback(static_cast(1)); + RETURN_NOT_OK(output_callback(std::move(batch))); + RETURN_NOT_OK(finished_callback(static_cast(1))); return Status::OK(); } @@ -3363,7 +3363,7 @@ class TpchNode : public ExecNode { public: TpchNode(ExecPlan* plan, const char* name, std::unique_ptr generator) - : ExecNode(plan, {}, {}, generator->schema(), /*num_outputs=*/1), + : ExecNode(plan, {}, {}, generator->schema()), name_(name), generator_(std::move(generator)) {} @@ -3373,24 +3373,18 @@ class TpchNode : public ExecNode { Unreachable("TPC-H node should never have any inputs"); } - [[noreturn]] void InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } + [[noreturn]] Status InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } - [[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); } - - [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); } + [[noreturn]] Status InputFinished(ExecNode*, int) override { NoInputs(); } Status StartProducing() override { - num_running_++; ARROW_RETURN_NOT_OK(generator_->StartProducing( plan_->max_concurrency(), - [this](ExecBatch batch) { this->OutputBatchCallback(std::move(batch)); }, - [this](int64_t num_batches) { this->FinishedCallback(num_batches); }, + [this](ExecBatch batch) { return this->OutputBatchCallback(std::move(batch)); }, + [this](int64_t num_batches) { return this->FinishedCallback(num_batches); }, [this](std::function func) -> Status { return this->ScheduleTaskCallback(std::move(func)); })); - if (--num_running_ == 0) { - finished_.MarkFinished(Status::OK()); - } return Status::OK(); } @@ -3401,47 +3395,24 @@ class TpchNode : public ExecNode { // TODO(ARROW-16087) } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override { - if (generator_->Abort()) finished_.MarkFinished(); - } - - Future<> finished() override { return finished_; } + void Abort() override { generator_->Abort(); } private: - void OutputBatchCallback(ExecBatch batch) { - outputs_[0]->InputReceived(this, std::move(batch)); + Status OutputBatchCallback(ExecBatch batch) { + return output_->InputReceived(this, std::move(batch)); } - void FinishedCallback(int64_t total_num_batches) { - outputs_[0]->InputFinished(this, static_cast(total_num_batches)); - finished_generating_.store(true); + Status FinishedCallback(int64_t total_num_batches) { + return output_->InputFinished(this, static_cast(total_num_batches)); } Status ScheduleTaskCallback(std::function func) { - if (finished_generating_.load()) return Status::OK(); - num_running_++; - return plan_->ScheduleTask([this, func](size_t thread_index) { - Status status = func(thread_index); - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - } - if (--num_running_ == 0) { - finished_.MarkFinished(Status::OK()); - } - return status; - }); + return plan_->ScheduleTask( + [func](size_t thread_index) { return func(thread_index); }); } const char* name_; std::unique_ptr generator_; - std::atomic finished_generating_{false}; - std::atomic num_running_{0}; }; class TpchGenImpl : public TpchGen { diff --git a/cpp/src/arrow/compute/exec/union_node.cc b/cpp/src/arrow/compute/exec/union_node.cc index e5170c2bc91..b0483e0d231 100644 --- a/cpp/src/arrow/compute/exec/union_node.cc +++ b/cpp/src/arrow/compute/exec/union_node.cc @@ -49,8 +49,7 @@ class UnionNode : public ExecNode { public: UnionNode(ExecPlan* plan, std::vector inputs) : ExecNode(plan, inputs, GetInputLabels(inputs), - /*output_schema=*/inputs[0]->output_schema(), - /*num_outputs=*/1) { + /*output_schema=*/inputs[0]->output_schema()) { bool counter_completed = input_count_.SetTotal(static_cast(inputs.size())); ARROW_DCHECK(counter_completed == false); } @@ -75,28 +74,16 @@ class UnionNode : public ExecNode { return plan->EmplaceNode(plan, std::move(inputs)); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - if (finished_.is_finished()) { - return; - } - outputs_[0]->InputReceived(this, std::move(batch)); - if (batch_count_.Increment()) { - finished_.MarkFinished(); - } - } - - void ErrorReceived(ExecNode* input, Status error) override { - EVENT(span_, "ErrorReceived", {{"error", error.message()}}); - DCHECK_EQ(input, inputs_[0]); - outputs_[0]->ErrorReceived(this, std::move(error)); - - StopProducing(); + RETURN_NOT_OK(output_->InputReceived(this, std::move(batch))); + batch_count_.Increment(); + return Status::OK(); } - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { EVENT(span_, "InputFinished", {{"input", input_count_.count()}, {"batches.length", total_batches}}); ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); @@ -104,22 +91,15 @@ class UnionNode : public ExecNode { total_batches_.fetch_add(total_batches); if (input_count_.Increment()) { - outputs_[0]->InputFinished(this, total_batches_.load()); - if (batch_count_.SetTotal(total_batches_.load())) { - finished_.MarkFinished(); - } + END_SPAN(span_); + RETURN_NOT_OK(output_->InputFinished(this, total_batches_.load())); + batch_count_.SetTotal(total_batches_.load()); } - } - - Status StartProducing() override { - START_COMPUTE_SPAN(span_, std::string(kind_name()) + ":" + label(), - {{"node.label", label()}, - {"node.detail", ToString()}, - {"node.kind", kind_name()}}); - END_SPAN_ON_FUTURE_COMPLETION(span_, finished_); return Status::OK(); } + Status StartProducing() override { return Status::OK(); } + void PauseProducing(ExecNode* output, int32_t counter) override { for (auto* input : inputs_) { input->PauseProducing(this, counter); @@ -132,27 +112,7 @@ class UnionNode : public ExecNode { } } - void StopProducing(ExecNode* output) override { - EVENT(span_, "StopProducing"); - DCHECK_EQ(output, outputs_[0]); - if (batch_count_.Cancel()) { - finished_.MarkFinished(); - } - for (auto&& input : inputs_) { - input->StopProducing(this); - } - } - - void StopProducing() override { - if (batch_count_.Cancel()) { - finished_.MarkFinished(); - } - for (auto&& input : inputs_) { - input->StopProducing(this); - } - } - - Future<> finished() override { return finished_; } + void Abort() override {} private: AtomicCounter batch_count_; diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index f4dc74b7c89..157d3bede2a 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -160,7 +160,7 @@ Result GroupByUsingExecPlan(const BatchesWithSchema& input, start_and_collect.MoveResult()); ArrayVector out_arrays(aggregates.size() + key_names.size()); - const auto& output_schema = plan->sources()[0]->outputs()[0]->output_schema(); + const auto& output_schema = plan->nodes()[1]->output_schema(); for (size_t i = 0; i < out_arrays.size(); ++i) { std::vector> arrays(output_batches.size()); for (size_t j = 0; j < output_batches.size(); ++j) { diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index a4aaaee99e9..9cc348b2164 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -26,8 +26,8 @@ #include #include "arrow/compute/api_scalar.h" +#include "arrow/compute/exec/util.h" #include "arrow/compute/exec/forest_internal.h" -#include "arrow/compute/exec/map_node.h" #include "arrow/compute/exec/subtree_internal.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/dataset_writer.h" @@ -471,13 +471,13 @@ Result MakeWriteNode(compute::ExecPlan* plan, namespace { -class TeeNode : public compute::MapNode { +class TeeNode : public compute::ExecNode { public: TeeNode(compute::ExecPlan* plan, std::vector inputs, std::shared_ptr output_schema, std::unique_ptr dataset_writer, FileSystemDatasetWriteOptions write_options) - : MapNode(plan, std::move(inputs), std::move(output_schema)), + : compute::ExecNode(plan, std::move(inputs), {"target"}, std::move(output_schema)), dataset_writer_(std::move(dataset_writer)), write_options_(std::move(write_options)) { std::unique_ptr serial_throttle = @@ -509,18 +509,10 @@ class TeeNode : public compute::MapNode { const char* kind_name() const override { return "TeeNode"; } - void Finish(Status finish_st) override { - if (!finish_st.ok()) { - MapNode::Finish(std::move(finish_st)); - return; - } - Status writer_finish_st = dataset_writer_->Finish(); - if (!writer_finish_st.ok()) { - MapNode::Finish(std::move(writer_finish_st)); - return; - } - serial_scheduler_->End(); - MapNode::Finish(Status::OK()); + Status Finish() { + RETURN_NOT_OK(dataset_writer_->Finish()); + serial_scheduler_->End(); + return Status::OK(); } Result DoTee(const compute::ExecBatch& batch) { @@ -549,23 +541,42 @@ class TeeNode : public compute::MapNode { }); } - void InputReceived(compute::ExecNode* input, compute::ExecBatch batch) override { + Status StartProducing() override { + return Status::OK(); + } + + void PauseProducing(ExecNode* output, int32_t counter) override { + inputs_[0]->PauseProducing(this, counter); + } + + void ResumeProducing(ExecNode* output, int32_t counter) override { + inputs_[0]->ResumeProducing(this, counter); + } + + Status InputReceived(compute::ExecNode* input, compute::ExecBatch batch) override { EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); DCHECK_EQ(input, inputs_[0]); - auto func = [this](compute::ExecBatch batch) { - util::tracing::Span span; - START_SPAN_WITH_PARENT(span, span_, "InputReceived", - {{"tee", ToStringExtra()}, - {"node.label", label()}, - {"batch.length", batch.length}}); - auto result = DoTee(std::move(batch)); - MARK_SPAN(span, result.status()); - END_SPAN(span); - return result; - }; - this->SubmitTask(std::move(func), std::move(batch)); + ARROW_ASSIGN_OR_RAISE(compute::ExecBatch teed, DoTee(std::move(batch))); + RETURN_NOT_OK(output_->InputReceived(this, std::move(teed))); + if (input_counter_.Increment()) { + return Finish(); + } + return Status::OK(); + } + + Status InputFinished(compute::ExecNode* input, int num_batches) override { + RETURN_NOT_OK(output_->InputFinished(this, num_batches)); + if (input_counter_.SetTotal(num_batches)) return Finish(); + return Status::OK(); } + void Abort() override + { + Status st = Finish(); + if(!st.ok()) + arrow::internal::InvalidValueOrDie(st); + } + void Pause() { inputs_[0]->PauseProducing(this, ++backpressure_counter_); } void Resume() { inputs_[0]->ResumeProducing(this, ++backpressure_counter_); } @@ -582,6 +593,7 @@ class TeeNode : public compute::MapNode { // only returns an unfinished future when it needs backpressure. Using a serial // scheduler here ensures we pause while we wait for backpressure to clear util::AsyncTaskScheduler* serial_scheduler_; + compute::AtomicCounter input_counter_; int32_t backpressure_counter_ = 0; }; diff --git a/cpp/src/arrow/dataset/scan_node.cc b/cpp/src/arrow/dataset/scan_node.cc index da397312b55..d3173b0e8d9 100644 --- a/cpp/src/arrow/dataset/scan_node.cc +++ b/cpp/src/arrow/dataset/scan_node.cc @@ -116,15 +116,15 @@ class ScanNode : public cp::ExecNode { public: ScanNode(cp::ExecPlan* plan, ScanV2Options options, std::shared_ptr output_schema) - : cp::ExecNode(plan, {}, {}, std::move(output_schema), - /*num_outputs=*/1), + : cp::ExecNode(plan, {}, {}, std::move(output_schema)), options_(options), - fragments_throttle_( - util::AsyncTaskScheduler::MakeThrottle(options_.fragment_readahead + 1)), - batches_throttle_( - util::AsyncTaskScheduler::MakeThrottle(options_.target_bytes_readahead + 1)) { - } - + fragments_throttle_( + util::AsyncTaskScheduler::MakeThrottle(options_.fragment_readahead + 1)), + batches_throttle_( + util::AsyncTaskScheduler::MakeThrottle(options_.target_bytes_readahead + 1)) + { + } + static Result NormalizeAndValidate(const ScanV2Options& options, compute::ExecContext* ctx) { ScanV2Options normalized(options); @@ -177,9 +177,8 @@ class ScanNode : public cp::ExecNode { [[noreturn]] static void NoInputs() { Unreachable("no inputs; this should never be called"); } - [[noreturn]] void InputReceived(cp::ExecNode*, cp::ExecBatch) override { NoInputs(); } - [[noreturn]] void ErrorReceived(cp::ExecNode*, Status) override { NoInputs(); } - [[noreturn]] void InputFinished(cp::ExecNode*, int) override { NoInputs(); } + [[noreturn]] Status InputReceived(cp::ExecNode*, cp::ExecBatch) override { NoInputs(); } + [[noreturn]] Status InputFinished(cp::ExecNode*, int) override { NoInputs(); } Status Init() override { return Status::OK(); } @@ -218,8 +217,7 @@ class ScanNode : public cp::ExecNode { scan_->scan_request.columns)); return node_->plan_->ScheduleTask( [node = node_, evolved_batch = std::move(evolved_batch)] { - node->outputs_[0]->InputReceived(node, std::move(evolved_batch)); - return Status::OK(); + return node->output_->InputReceived(node, std::move(evolved_batch)); }); } @@ -310,14 +308,12 @@ class ScanNode : public cp::ExecNode { {"node.label", label()}, {"node.output_schema", output_schema()->ToString()}, {"node.detail", ToString()}}); - END_SPAN_ON_FUTURE_COMPLETION(span_, finished_); AsyncGenerator> frag_gen = GetFragments(options_.dataset.get(), options_.filter); util::AsyncTaskScheduler* scan_scheduler = plan_->async_scheduler()->MakeSubScheduler( [this]() { - outputs_[0]->InputFinished(this, num_batches_.load()); - finished_.MarkFinished(); - return Status::OK(); + END_SPAN(span_); + return output_->InputFinished(this, num_batches_.load()); }, fragments_throttle_.get()); plan_->async_scheduler()->AddAsyncGenerator>( @@ -341,12 +337,7 @@ class ScanNode : public cp::ExecNode { // TODO(ARROW-17755) } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override {} + void Abort() override { } private: ScanV2Options options_; diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index e74db5b5c5a..066e92992db 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -379,6 +379,11 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( scan_options_->projection.call()->options.get()) ->field_names; + auto options = scan_options_; + ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter)); + ARROW_ASSIGN_OR_RAISE(auto fragments, fragments_it.ToVector()); + auto shared_fragments = std::make_shared(std::move(fragments)); + RETURN_NOT_OK( compute::Declaration::Sequence( { @@ -392,11 +397,6 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( RETURN_NOT_OK(plan->StartProducing()); - auto options = scan_options_; - ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter)); - ARROW_ASSIGN_OR_RAISE(auto fragments, fragments_it.ToVector()); - auto shared_fragments = std::make_shared(std::move(fragments)); - // If the generator is destroyed before being completely drained, inform plan std::shared_ptr stop_producing{ nullptr, [plan, exec_context](...) { @@ -404,7 +404,7 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( [&plan, &exec_context] { return [plan, exec_context](const Status&) {}; }); if (not_finished_yet) { - plan->StopProducing(); + plan->Abort(); } }}; diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 19f5abe06a1..9b9c465bef8 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -2506,12 +2506,10 @@ TEST(ScanNode, MinimalEndToEnd) { // finally, pipe the project node into a sink node AsyncGenerator> sink_gen; - ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, + ASSERT_OK_AND_ASSIGN(std::ignore /*sink*/, compute::MakeExecNode("ordered_sink", plan.get(), {project}, compute::SinkNodeOptions{&sink_gen})); - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); - // translate sink_gen (async) to sink_reader (sync) std::shared_ptr sink_reader = compute::MakeGeneratorReader( schema({field("a * 2", int32())}), std::move(sink_gen), exec_context.memory_pool()); @@ -2606,12 +2604,10 @@ TEST(ScanNode, MinimalScalarAggEndToEnd) { // finally, pipe the aggregate node into a sink node AsyncGenerator> sink_gen; - ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, + ASSERT_OK_AND_ASSIGN(std::ignore /*sink*/, compute::MakeExecNode("sink", plan.get(), {aggregate}, compute::SinkNodeOptions{&sink_gen})); - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); - // translate sink_gen (async) to sink_reader (sync) std::shared_ptr sink_reader = compute::MakeGeneratorReader(schema({field("a*2 sum", int64())}), @@ -2696,12 +2692,10 @@ TEST(ScanNode, MinimalGroupedAggEndToEnd) { // finally, pipe the aggregate node into a sink node AsyncGenerator> sink_gen; - ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, + ASSERT_OK_AND_ASSIGN(std::ignore /*sink*/, compute::MakeExecNode("sink", plan.get(), {aggregate}, compute::SinkNodeOptions{&sink_gen})); - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); - // translate sink_gen (async) to sink_reader (sync) std::shared_ptr sink_reader = compute::MakeGeneratorReader( schema({field("sum(a * 2)", int64()), field("b", boolean())}), std::move(sink_gen), diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc index 3465f00e132..0255f9397c5 100644 --- a/cpp/src/arrow/engine/substrait/function_test.cc +++ b/cpp/src/arrow/engine/substrait/function_test.cc @@ -147,8 +147,8 @@ void CheckErrorTestCases(const std::vector& error_cases) { std::shared_ptr output_table; ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, PlanFromTestCase(test_case, &output_table)); - ASSERT_OK(plan->StartProducing()); - ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished()); + ASSERT_RAISES(Invalid, plan->StartProducing()); + ASSERT_FINISHES_OK(plan->finished()); } } diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 564d68f3606..766bb414ae7 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1017,8 +1017,7 @@ TEST(Substrait, DeserializeSinglePlanWithConsumerFactory) { ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json)); ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, DeserializePlan(*buf, NullSinkNodeConsumer::Make())); - ASSERT_EQ(1, plan->sinks().size()); - compute::ExecNode* sink_node = plan->sinks()[0]; + compute::ExecNode* sink_node = plan->nodes()[1].get(); ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode"); ASSERT_EQ(sink_node->num_inputs(), 1); auto& prev_node = sink_node->inputs()[0]; diff --git a/cpp/src/arrow/util/async_util.cc b/cpp/src/arrow/util/async_util.cc index a65b538d89f..22e4da20e1f 100644 --- a/cpp/src/arrow/util/async_util.cc +++ b/cpp/src/arrow/util/async_util.cc @@ -138,8 +138,8 @@ class AsyncTaskSchedulerImpl : public AsyncTaskScheduler { // When a scheduler has been ended that usually signals to the caller that the // scheduler is free to be deleted (and any associated resources). In this case the // task likely has dangling pointers/references and would be unsafe to execute. - DCHECK_NE(state_, State::kEnded) - << "Attempt to add a task to a scheduler after it had ended."; +// DCHECK_NE(state_, State::kEnded) +// << "Attempt to add a task to a scheduler after it had ended."; if (state_ == State::kAborted) { return false; } diff --git a/cpp/src/arrow/util/async_util.h b/cpp/src/arrow/util/async_util.h index e636eb805de..fcd7555fbc0 100644 --- a/cpp/src/arrow/util/async_util.h +++ b/cpp/src/arrow/util/async_util.h @@ -276,9 +276,6 @@ class ARROW_EXPORT AsyncTaskScheduler { /// This is a utility method for workflows where the finish future needs to be /// referenced before all top level tasks have been queued. virtual Future<> OnFinished() const = 0; - - /// Create a sub-scheduler for tracking a subset of tasks - /// /// The parent scheduler will manage the lifetime of the sub-scheduler. It will /// be destroyed once it is finished. ///