diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index fd5a6bae04d..04407a8818c 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -1985,22 +1985,16 @@ garrow_execute_plan_validate(GArrowExecutePlan *plan, /** * garrow_execute_plan_start: * @plan: A #GArrowExecutePlan. - * @error: (nullable): Return location for a #GError or %NULL. * * Starts this plan. * - * Returns: %TRUE on success, %FALSE on error. - * * Since: 6.0.0 */ -gboolean -garrow_execute_plan_start(GArrowExecutePlan *plan, - GError **error) +void +garrow_execute_plan_start(GArrowExecutePlan *plan) { auto arrow_plan = garrow_execute_plan_get_raw(plan); - return garrow::check(error, - arrow_plan->StartProducing(), - "[execute-plan][start]"); + arrow_plan->StartProducing(); } /** diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index 1ac1d05258c..2b37efec88e 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -369,9 +369,8 @@ gboolean garrow_execute_plan_validate(GArrowExecutePlan *plan, GError **error); GARROW_AVAILABLE_IN_6_0 -gboolean -garrow_execute_plan_start(GArrowExecutePlan *plan, - GError **error); +void +garrow_execute_plan_start(GArrowExecutePlan *plan); GARROW_AVAILABLE_IN_6_0 void garrow_execute_plan_stop(GArrowExecutePlan *plan); diff --git a/cpp/examples/arrow/compute_register_example.cc b/cpp/examples/arrow/compute_register_example.cc index 1b96dd42220..1df34f17459 100644 --- a/cpp/examples/arrow/compute_register_example.cc +++ b/cpp/examples/arrow/compute_register_example.cc @@ -74,14 +74,11 @@ class ExampleNode : public cp::ExecNode { ExampleNode(ExecNode* input, const ExampleNodeOptions&) : ExecNode(/*plan=*/input->plan(), /*inputs=*/{input}, /*input_labels=*/{"ignored"}, - /*output_schema=*/input->output_schema(), /*num_outputs=*/1) {} + /*output_schema=*/input->output_schema()) {} const char* kind_name() const override { return "ExampleNode"; } - arrow::Status StartProducing() override { - outputs_[0]->InputFinished(this, 0); - return arrow::Status::OK(); - } + arrow::Status StartProducing() override { return output_->InputFinished(this, 0); } void ResumeProducing(ExecNode* output, int32_t counter) override { inputs_[0]->ResumeProducing(this, counter); @@ -90,14 +87,14 @@ class ExampleNode : public cp::ExecNode { inputs_[0]->PauseProducing(this, counter); } - void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); } - void StopProducing() override { inputs_[0]->StopProducing(); } + arrow::Status StopProducingImpl() override { return arrow::Status::OK(); } - void InputReceived(ExecNode* input, cp::ExecBatch batch) override {} - void ErrorReceived(ExecNode* input, arrow::Status error) override {} - void InputFinished(ExecNode* input, int total_batches) override {} - - arrow::Future<> finished() override { return inputs_[0]->finished(); } + arrow::Status InputReceived(ExecNode* input, cp::ExecBatch batch) override { + return arrow::Status::OK(); + } + arrow::Status InputFinished(ExecNode* input, int total_batches) override { + return arrow::Status::OK(); + } }; arrow::Result ExampleExecNodeFactory(cp::ExecPlan* plan, diff --git a/cpp/examples/arrow/engine_substrait_consumption.cc b/cpp/examples/arrow/engine_substrait_consumption.cc index aef189952c6..354375ce59b 100644 --- a/cpp/examples/arrow/engine_substrait_consumption.cc +++ b/cpp/examples/arrow/engine_substrait_consumption.cc @@ -162,7 +162,7 @@ arrow::Status RunSubstraitConsumer(int argc, char** argv) { // Start the plan... std::cout << std::string(50, '#') << " consuming batches:" << std::endl; - ARROW_RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); // ... and wait for it to finish ARROW_RETURN_NOT_OK(plan->finished().status()); diff --git a/cpp/examples/arrow/execution_plan_documentation_examples.cc b/cpp/examples/arrow/execution_plan_documentation_examples.cc index a72db97930c..25baaa7d0a1 100644 --- a/cpp/examples/arrow/execution_plan_documentation_examples.cc +++ b/cpp/examples/arrow/execution_plan_documentation_examples.cc @@ -531,7 +531,7 @@ arrow::Status ExecutePlanAndCollectAsTableWithCustomSink( ARROW_RETURN_NOT_OK(plan->Validate()); std::cout << "ExecPlan created : " << plan->ToString() << std::endl; // start the ExecPlan - ARROW_RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); // collect sink_reader into a Table std::shared_ptr response_table; @@ -751,7 +751,7 @@ arrow::Status TableSinkExample() { ARROW_RETURN_NOT_OK(plan->Validate()); std::cout << "ExecPlan created : " << plan->ToString() << std::endl; // start the ExecPlan - ARROW_RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); // Wait for the plan to finish auto finished = plan->finished(); diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 98f8f92cc37..b4726b4fbc9 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -69,8 +69,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode kernels, std::vector>> states) : ExecNode(plan, std::move(inputs), {"target"}, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), + /*output_schema=*/std::move(output_schema)), target_field_ids_(std::move(target_field_ids)), aggs_(std::move(aggs)), kernels_(std::move(kernels)), @@ -155,36 +154,33 @@ class ScalarAggregateNode : public ExecNode, public TracedNodequery_context()->GetThreadIndex(); - if (ErrorIfNotOk(DoConsume(ExecSpan(batch), thread_index))) return; + ARROW_RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index)); if (input_counter_.Increment()) { - ErrorIfNotOk(Finish()); + return Finish(); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - DCHECK_EQ(input, inputs_[0]); - outputs_[0]->ErrorReceived(this, std::move(error)); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { + EVENT_ON_CURRENT_SPAN("InputFinished", {{"batches.length", total_batches}}); DCHECK_EQ(input, inputs_[0]); if (input_counter_.SetTotal(total_batches)) { - ErrorIfNotOk(Finish()); + return Finish(); } + return Status::OK(); } Status StartProducing() override { NoteStartProducing(ToStringExtra()); // Scalar aggregates will only output a single batch - outputs_[0]->InputFinished(this, 1); - return Status::OK(); + return output_->InputFinished(this, 1); } void PauseProducing(ExecNode* output, int32_t counter) override { @@ -195,17 +191,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNodeResumeProducing(this, counter); } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override { - if (input_counter_.Cancel()) { - finished_.MarkFinished(); - } - inputs_[0]->StopProducing(this); - } + Status StopProducingImpl() override { return Status::OK(); } protected: std::string ToStringExtra(int indent = 0) const override { @@ -234,9 +220,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNodefinalize(&ctx, &batch.values[i])); } - outputs_[0]->InputReceived(this, std::move(batch)); - finished_.MarkFinished(); - return Status::OK(); + return output_->InputReceived(this, std::move(batch)); } const std::vector target_field_ids_; @@ -254,8 +238,7 @@ class GroupByNode : public ExecNode, public TracedNode { std::vector key_field_ids, std::vector agg_src_field_ids, std::vector aggs, std::vector agg_kernels) - : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema), - /*num_outputs=*/1), + : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)), key_field_ids_(std::move(key_field_ids)), agg_src_field_ids_(std::move(agg_src_field_ids)), aggs_(std::move(aggs)), @@ -263,14 +246,8 @@ class GroupByNode : public ExecNode, public TracedNode { Status Init() override { output_task_group_id_ = plan_->query_context()->RegisterTaskGroup( - [this](size_t, int64_t task_id) { - OutputNthBatch(task_id); - return Status::OK(); - }, - [this](size_t) { - finished_.MarkFinished(); - return Status::OK(); - }); + [this](size_t, int64_t task_id) { return OutputNthBatch(task_id); }, + [](size_t) { return Status::OK(); }); return Status::OK(); } @@ -455,15 +432,13 @@ class GroupByNode : public ExecNode, public TracedNode { return out_data; } - void OutputNthBatch(int64_t n) { - // bail if StopProducing was called - if (finished_.is_finished()) return; - + Status OutputNthBatch(int64_t n) { int64_t batch_size = output_batch_size(); - outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size)); + return output_->InputReceived(this, out_data_.Slice(batch_size * n, batch_size)); } - Status DoOutputResult() { + Status OutputResult() { + auto scope = TraceFinish(); // To simplify merging, ensure that the first grouper is nonempty for (size_t i = 0; i < local_states_.size(); i++) { if (local_states_[i].grouper) { @@ -476,60 +451,31 @@ class GroupByNode : public ExecNode, public TracedNode { ARROW_ASSIGN_OR_RAISE(out_data_, Finalize()); int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size()); - outputs_[0]->InputFinished(this, static_cast(num_output_batches)); - Status st = - plan_->query_context()->StartTaskGroup(output_task_group_id_, num_output_batches); - if (st.IsCancelled()) { - // This means the user has cancelled/aborted the plan. We will not send any batches - // and end immediately. - finished_.MarkFinished(); - return Status::OK(); - } else { - return st; - } - return Status::OK(); + RETURN_NOT_OK(output_->InputFinished(this, static_cast(num_output_batches))); + return plan_->query_context()->StartTaskGroup(output_task_group_id_, + num_output_batches); } - void OutputResult() { - auto scope = TraceFinish(); - // If something goes wrong outputting the result we need to make sure - // we still mark finished. - Status st = DoOutputResult(); - if (!st.ok()) { - finished_.MarkFinished(st); - } - } - - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { auto scope = TraceInputReceived(batch); - // bail if StopProducing was called - if (finished_.is_finished()) return; - DCHECK_EQ(input, inputs_[0]); - if (ErrorIfNotOk(Consume(ExecSpan(batch)))) return; + ARROW_RETURN_NOT_OK(Consume(ExecSpan(batch))); if (input_counter_.Increment()) { - OutputResult(); + return OutputResult(); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - DCHECK_EQ(input, inputs_[0]); - - outputs_[0]->ErrorReceived(this, std::move(error)); - } - - void InputFinished(ExecNode* input, int total_batches) override { - // bail if StopProducing was called - if (finished_.is_finished()) return; - + Status InputFinished(ExecNode* input, int total_batches) override { DCHECK_EQ(input, inputs_[0]); if (input_counter_.SetTotal(total_batches)) { - OutputResult(); + return OutputResult(); } + return Status::OK(); } Status StartProducing() override { @@ -548,16 +494,7 @@ class GroupByNode : public ExecNode, public TracedNode { // Without spillover there is way to handle backpressure in this node } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - - if (input_counter_.Cancel()) { - finished_.MarkFinished(); - } - inputs_[0]->StopProducing(this); - } - - void StopProducing() override { StopProducing(outputs_[0]); } + Status StopProducingImpl() override { return Status::OK(); } protected: std::string ToStringExtra(int indent = 0) const override { diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 0ad2a951f11..f53dc07ac59 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -1100,15 +1100,18 @@ class AsofJoinNode : public ExecNode { ~Defer() noexcept { callable(); } }; - void EndFromProcessThread() { + void EndFromProcessThread(Status st = Status::OK()) { // We must spawn a new task to transfer off the process thread when // marking this finished. Otherwise there is a chance that doing so could // mark the plan finished which may destroy the plan which will destroy this // node which will cause us to join on ourselves. - ErrorIfNotOk(plan_->query_context()->executor()->Spawn([this] { - Defer cleanup([this]() { finished_.MarkFinished(); }); - outputs_[0]->InputFinished(this, batches_produced_); - })); + ARROW_UNUSED( + plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable { + Defer cleanup([this, &st]() { process_task_.MarkFinished(st); }); + if (st.ok()) { + st = output_->InputFinished(this, batches_produced_); + } + })); } bool CheckEnded() { @@ -1134,10 +1137,12 @@ class AsofJoinNode : public ExecNode { if (!out_rb) break; ++batches_produced_; ExecBatch out_b(*out_rb); - outputs_[0]->InputReceived(this, std::move(out_b)); + Status st = output_->InputReceived(this, std::move(out_b)); + if (!st.ok()) { + EndFromProcessThread(std::move(st)); + } } else { - ErrorIfNotOk(result.status()); - EndFromProcessThread(); + EndFromProcessThread(result.status()); return false; } } @@ -1453,24 +1458,19 @@ class AsofJoinNode : public ExecNode { const char* kind_name() const override { return "AsofJoinNode"; } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { // Get the input ARROW_DCHECK(std_has(inputs_, input)); size_t k = std_find(inputs_, input) - inputs_.begin(); // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); - Status st = state_.at(k)->Push(rb); - if (!st.ok()) { - ErrorReceived(input, st); - return; - } + ARROW_RETURN_NOT_OK(state_.at(k)->Push(rb)); process_.Push(true); + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - outputs_[0]->ErrorReceived(this, std::move(error)); - } - void InputFinished(ExecNode* input, int total_batches) override { + + Status InputFinished(ExecNode* input, int total_batches) override { { std::lock_guard guard(gate_); ARROW_DCHECK(std_has(inputs_, input)); @@ -1482,19 +1482,28 @@ class AsofJoinNode : public ExecNode { // know whether the RHS of the join is up-to-date until we know that the table is // finished. process_.Push(true); + return Status::OK(); } - Status StartProducing() override { return Status::OK(); } + + Status StartProducing() override { + ARROW_ASSIGN_OR_RAISE(process_task_, plan_->query_context()->BeginExternalTask( + "AsofJoinNode::ProcessThread")); + if (!process_task_.is_valid()) { + // Plan has already aborted. Do not start process thread + return Status::OK(); + } + process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this); + return Status::OK(); + } + void PauseProducing(ExecNode* output, int32_t counter) override {} void ResumeProducing(ExecNode* output, int32_t counter) override {} - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - void StopProducing() override { + + Status StopProducingImpl() override { process_.Clear(); process_.Push(false); + return Status::OK(); } - arrow::Future<> finished() override { return finished_; } private: std::vector indices_of_on_key_; @@ -1515,6 +1524,7 @@ class AsofJoinNode : public ExecNode { ConcurrentQueue process_; // Worker thread std::thread process_thread_; + Future<> process_task_; // In-progress batches produced int batches_produced_ = 0; @@ -1528,8 +1538,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector> key_hashers, bool must_hash, bool may_rehash) : ExecNode(plan, inputs, input_labels, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), + /*output_schema=*/std::move(output_schema)), indices_of_on_key_(std::move(indices_of_on_key)), indices_of_by_key_(std::move(indices_of_by_key)), key_hashers_(std::move(key_hashers)), @@ -1538,7 +1547,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, tolerance_(tolerance), backpressure_counter_(1), process_(), - process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) {} + process_thread_() {} namespace internal { void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index a1405265372..376c16bb909 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -277,12 +277,9 @@ void DoInvalidPlanTest(const BatchesWithSchema& l_batches, "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); if (fail_on_plan_creation) { - AsyncGenerator> sink_gen; - ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) - .AddToPlan(plan.get())); - EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid, - ::testing::HasSubstr(expected_error_str), - StartAndCollect(plan.get(), sink_gen)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr(expected_error_str), + DeclarationToStatus(std::move(join), /*use_threads=*/false)); } else { EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_error_str), join.AddToPlan(plan.get())); diff --git a/cpp/src/arrow/compute/exec/benchmark_util.cc b/cpp/src/arrow/compute/exec/benchmark_util.cc index 3c4dda2992a..a3cd86d26d2 100644 --- a/cpp/src/arrow/compute/exec/benchmark_util.cc +++ b/cpp/src/arrow/compute/exec/benchmark_util.cc @@ -66,11 +66,11 @@ Status BenchmarkIsolatedNodeOverhead(benchmark::State& state, int task_group_id = scheduler->RegisterTaskGroup( [&](size_t thread_id, int64_t task_id) { - node->InputReceived(source_node, data.batches[task_id]); - return Status::OK(); + return node->InputReceived(source_node, data.batches[task_id]); }, [&](size_t thread_id) { - node->InputFinished(source_node, static_cast(data.batches.size())); + RETURN_NOT_OK( + node->InputFinished(source_node, static_cast(data.batches.size()))); std::unique_lock lk(mutex); all_tasks_finished_cv.notify_one(); return Status::OK(); @@ -95,9 +95,6 @@ Status BenchmarkIsolatedNodeOverhead(benchmark::State& state, ARROW_RETURN_NOT_OK( scheduler->StartTaskGroup(thread_indexer(), task_group_id, num_batches)); all_tasks_finished_cv.wait(lk); - if (!node->finished().is_finished()) { - return Status::Invalid("All tasks were finsihed but the node was not finished"); - } } state.counters["rows_per_second"] = benchmark::Counter( static_cast(state.iterations() * num_batches * batch_size), diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 19081267da7..896eafb58c3 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -28,6 +28,7 @@ #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/query_context.h" #include "arrow/compute/exec/task_util.h" +#include "arrow/compute/exec/util.h" #include "arrow/compute/registry.h" #include "arrow/datum.h" #include "arrow/record_batch.h" @@ -41,6 +42,8 @@ #include "arrow/util/tracing_internal.h" #include "arrow/util/vector.h" +using namespace std::string_view_literals; // NOLINT + namespace arrow { using internal::checked_cast; @@ -67,16 +70,13 @@ struct ExecPlanImpl : public ExecPlan { } } + const NodeVector& nodes() const { return node_ptrs_; } + ExecNode* AddNode(std::unique_ptr node) { if (node->label().empty()) { node->SetLabel(ToChars(auto_label_counter_++)); } - if (node->num_inputs() == 0) { - sources_.push_back(node.get()); - } - if (node->num_outputs() == 0) { - sinks_.push_back(node.get()); - } + node_ptrs_.push_back(node.get()); nodes_.push_back(std::move(node)); return nodes_.back().get(); } @@ -91,18 +91,28 @@ struct ExecPlanImpl : public ExecPlan { return Status::OK(); } - Status StartProducing() { + void StartProducing() { + if (finished_.is_finished()) { + finished_ = Future<>::MakeFinished( + Status::Invalid("StartProducing called after plan had already finished")); + return; + } if (started_) { - return Status::Invalid("restarted ExecPlan"); + finished_.MarkFinished( + Status::Invalid("StartProducing called on a plan that had already started.")); + return; } if (query_context_.exec_context()->executor() == nullptr) { - return Status::Invalid( + finished_.MarkFinished(Status::Invalid( "An exec plan must have an executor for CPU tasks. To run without threads use " "a SerialExeuctor (the arrow::compute::DeclarationTo... methods should take " - "care of this for you and are an easier way to execute an ExecPlan.)"); + "care of this for you and are an easier way to execute an ExecPlan.)")); + return; } if (query_context_.io_context()->executor() == nullptr) { - return Status::Invalid("An exec plan must have an I/O executor for I/O tasks."); + finished_.MarkFinished( + Status::Invalid("An exec plan must have an I/O executor for I/O tasks.")); + return; } started_ = true; @@ -130,18 +140,9 @@ struct ExecPlanImpl : public ExecPlan { }); } #endif - // TODO(weston) The entire concept of ExecNode::finished() will hopefully go - // away soon (or at least be replaced by a sub-scheduler to facilitate OT) for (auto& n : nodes_) { RETURN_NOT_OK(n->Init()); } - for (auto& n : nodes_) { - std::string qualified_label = std::string(n->kind_name()) + ":" + n->label(); - std::string wait_for_finish = - "ExecPlan::WaitForFinish(" + qualified_label + ")"; - async_scheduler->AddSimpleTask([&] { return n->finished(); }, - std::move(wait_for_finish)); - } ctx->scheduler()->RegisterEnd(); int num_threads = 1; @@ -155,7 +156,8 @@ struct ExecPlanImpl : public ExecPlan { [ctx](std::function fn) -> Status { // TODO(weston) add names to synchronous scheduler so we can use something // better than sync-scheduler-task here - return ctx->ScheduleTask(std::move(fn), "sync-scheduler-task"); + ctx->ScheduleTask(std::move(fn), "sync-scheduler-task"); + return Status::OK(); }, /*concurrent_tasks=*/2 * num_threads, sync_execution)); @@ -175,32 +177,39 @@ struct ExecPlanImpl : public ExecPlan { bool expected = false; if (stopped_.compare_exchange_strong(expected, true)) { StopProducingImpl(it.base(), sorted_nodes_.end()); - for (NodeVector::iterator fw_it = sorted_nodes_.begin(); - fw_it != it.base(); ++fw_it) { - Future<> fut = (*fw_it)->finished(); - if (!fut.is_finished()) fut.MarkFinished(); - } } return st; } } return st; }, - [this](const Status& st) { StopProducing(); }); - scheduler_finished.AddCallback( - [this](const Status& st) { finished_.MarkFinished(st); }); - // TODO(weston) Do we really need to return status here? Could we change this return - // to void? - if (finished_.is_finished()) { - return finished_.status(); - } else { - return Status::OK(); - } + [this](const Status& st) { + // If an error occurs we call StopProducing. The scheduler will already have + // stopped scheduling new tasks at this point. However, any nodes that are + // dealing with external tasks will need to trigger those external tasks to end + // early. + StopProducing(); + }); + scheduler_finished.AddCallback([this](const Status& st) { + if (st.ok()) { + if (stopped_.load()) { + finished_.MarkFinished(Status::Cancelled("Plan was cancelled early.")); + } else { + finished_.MarkFinished(); + } + } else { + finished_.MarkFinished(st); + } + }); } void StopProducing() { - DCHECK(started_) << "stopped an ExecPlan which never started"; - EVENT(span_, "ExecPlan::StopProducing"); + if (!started_) { + started_ = true; + finished_.MarkFinished(Status::Invalid( + "StopProducing was called before StartProducing. The plan never ran.")); + } + EVENT(span_, "StopProducing"); bool expected = false; if (stopped_.compare_exchange_strong(expected, true)) { query_context()->scheduler()->Abort( @@ -212,7 +221,17 @@ struct ExecPlanImpl : public ExecPlan { void StopProducingImpl(It begin, It end) { for (auto it = begin; it != end; ++it) { auto node = *it; - node->StopProducing(); + EVENT_ON_CURRENT_SPAN( + "StopProducing:" + node->label(), + {{"node.label", node->label()}, {"node.kind_name", node->kind_name()}}); + Status st = node->StopProducing(); + if (!st.ok()) { + // If an error occurs during StopProducing then we submit a task to fail. If we + // have already aborted then this will be ignored. This way the failing status + // will get communicated to finished_. + query_context()->async_scheduler()->AddSimpleTask( + [st] { return st; }, "ExecPlan::StopProducingErrorReporter"sv); + } } } @@ -314,7 +333,7 @@ struct ExecPlanImpl : public ExecPlan { bool started_ = false; std::atomic stopped_{false}; std::vector> nodes_; - NodeVector sources_, sinks_; + NodeVector node_ptrs_; NodeVector sorted_nodes_; uint32_t auto_label_counter_ = 0; util::tracing::Span span_; @@ -380,17 +399,15 @@ ExecNode* ExecPlan::AddNode(std::unique_ptr node) { return ToDerived(this)->AddNode(std::move(node)); } -const ExecPlan::NodeVector& ExecPlan::sources() const { - return ToDerived(this)->sources_; -} - -const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; } - QueryContext* ExecPlan::query_context() { return &ToDerived(this)->query_context_; } +const ExecPlanImpl::NodeVector& ExecPlan::nodes() const { + return ToDerived(this)->nodes(); +} + Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } -Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } +void ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); } @@ -406,14 +423,17 @@ std::string ExecPlan::ToString() const { return ToDerived(this)->ToString(); } ExecNode::ExecNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - std::shared_ptr output_schema, int num_outputs) - : plan_(plan), + std::shared_ptr output_schema) + : stopped_(false), + plan_(plan), inputs_(std::move(inputs)), input_labels_(std::move(input_labels)), - output_schema_(std::move(output_schema)), - num_outputs_(num_outputs) { + output_schema_(std::move(output_schema)) { for (auto input : inputs_) { - input->outputs_.push_back(this); + DCHECK_NE(input, nullptr) << " null input"; + DCHECK_EQ(input->output_, nullptr) << " attempt to add a second output to a node"; + DCHECK(!input->is_sink()) << " attempt to add a sink node as input"; + input->output_ = this; } } @@ -425,15 +445,18 @@ Status ExecNode::Validate() const { num_inputs(), ", actual ", input_labels_.size(), ")"); } - if (static_cast(outputs_.size()) != num_outputs_) { - return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ", - num_outputs(), ", actual ", outputs_.size(), ")"); - } - - for (auto out : outputs_) { - auto input_index = GetNodeIndex(out->inputs(), this); + if (is_sink()) { + if (output_ != nullptr) { + return Status::Invalid("Sink node, '", label(), "' has an output"); + } + return Status::OK(); + } else { + if (output_ == nullptr) { + return Status::Invalid("No output for node, '", label(), "'"); + } + auto input_index = GetNodeIndex(output_->inputs(), this); if (!input_index) { - return Status::Invalid("Node '", label(), "' outputs to node '", out->label(), + return Status::Invalid("Node '", label(), "' outputs to node '", output_->label(), "' but is not listed as an input."); } } @@ -441,6 +464,17 @@ Status ExecNode::Validate() const { return Status::OK(); } +Status ExecNode::StopProducing() { + bool expected = false; + if (stopped_.compare_exchange_strong(expected, true)) { + ARROW_RETURN_NOT_OK(StopProducingImpl()); + for (auto* input : inputs_) { + ARROW_RETURN_NOT_OK(input->StopProducing()); + } + } + return Status::OK(); +} + std::string ExecNode::ToString(int indent) const { std::stringstream ss; @@ -462,15 +496,6 @@ std::string ExecNode::ToString(int indent) const { std::string ExecNode::ToStringExtra(int indent) const { return ""; } -bool ExecNode::ErrorIfNotOk(Status status) { - if (status.ok()) return false; - - for (auto out : outputs_) { - out->ErrorReceived(this, out == outputs_.back() ? std::move(status) : status); - } - return true; -} - std::shared_ptr MakeGeneratorReader( std::shared_ptr schema, std::function>()> gen, MemoryPool* pool) { @@ -560,7 +585,7 @@ Future> DeclarationToTableAsync(Declaration declaration, {declaration, {"table_sink", TableSinkNodeOptions(output_table.get())}}); ARROW_RETURN_NOT_OK(with_sink.AddToPlan(exec_plan.get())); ARROW_RETURN_NOT_OK(exec_plan->Validate()); - ARROW_RETURN_NOT_OK(exec_plan->StartProducing()); + exec_plan->StartProducing(); return exec_plan->finished().Then([exec_plan, output_table] { return *output_table; }); } @@ -638,7 +663,7 @@ Future DeclarationToExecBatchesAsync(Declaration declar {declaration, {"sink", SinkNodeOptions(&sink_gen, &out_schema)}}); ARROW_RETURN_NOT_OK(with_sink.AddToPlan(exec_plan.get())); ARROW_RETURN_NOT_OK(exec_plan->Validate()); - ARROW_RETURN_NOT_OK(exec_plan->StartProducing()); + exec_plan->StartProducing(); auto collected_fut = CollectAsyncGenerator(sink_gen); return AllFinished({exec_plan->finished(), Future<>(collected_fut)}) .Then([collected_fut, exec_plan, @@ -681,14 +706,14 @@ Future<> DeclarationToStatusAsync(Declaration declaration, ExecContext exec_cont ARROW_ASSIGN_OR_RAISE(std::shared_ptr exec_plan, ExecPlan::Make(exec_context)); ARROW_ASSIGN_OR_RAISE(ExecNode * last_node, declaration.AddToPlan(exec_plan.get())); - for (int i = 0; i < last_node->num_outputs(); i++) { - ARROW_RETURN_NOT_OK( + if (!last_node->is_sink()) { + Declaration null_sink = Declaration("consuming_sink", {last_node}, - ConsumingSinkNodeOptions(NullSinkNodeConsumer::Make())) - .AddToPlan(exec_plan.get())); + ConsumingSinkNodeOptions(NullSinkNodeConsumer::Make())); + ARROW_RETURN_NOT_OK(null_sink.AddToPlan(exec_plan.get())); } ARROW_RETURN_NOT_OK(exec_plan->Validate()); - ARROW_RETURN_NOT_OK(exec_plan->StartProducing()); + exec_plan->StartProducing(); // Keep the exec_plan alive until it finishes return exec_plan->finished().Then([exec_plan]() {}); } @@ -763,7 +788,8 @@ Result>> DeclarationToRecordBatchGen {declaration, {"sink", SinkNodeOptions(&converter->exec_batch_gen, &converter->schema)}}); ARROW_RETURN_NOT_OK(with_sink.AddToPlan(plan.get())); - ARROW_RETURN_NOT_OK(plan->StartProducing()); + ARROW_RETURN_NOT_OK(plan->Validate()); + plan->StartProducing(); converter->exec_plan = std::move(plan); *out_schema = converter->schema; return [conv = std::move(converter)] { return (*conv)(); }; diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 40b217273f4..0fcfb367544 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -52,6 +52,9 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { QueryContext* query_context(); + /// \brief retrieve the nodes in the plan + const NodeVector& nodes() const; + /// Make an empty exec plan static Result> Make( QueryOptions options, ExecContext exec_context = *threaded_exec_context(), @@ -79,27 +82,22 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { return out; } - /// The initial inputs - const NodeVector& sources() const; - - /// The final outputs - const NodeVector& sinks() const; - Status Validate(); /// \brief Start producing on all nodes /// /// Nodes are started in reverse topological order, such that any node /// is started before all of its inputs. - Status StartProducing(); + void StartProducing(); /// \brief Stop producing on all nodes /// - /// Nodes are stopped in topological order, such that any node - /// is stopped before all of its outputs. + /// Triggers all sources to stop producing new data. In order to cleanly stop the plan + /// will continue to run any tasks that are already in progress. The caller should + /// still wait for `finished` to complete before destroying the plan. void StopProducing(); - /// \brief A future which will be marked finished when all nodes have stopped producing. + /// \brief A future which will be marked finished when all tasks have finished. Future<> finished(); /// \brief Return whether the plan has non-empty metadata @@ -119,18 +117,20 @@ class ARROW_EXPORT ExecNode { virtual const char* kind_name() const = 0; - // The number of inputs/outputs expected by this node + // The number of inputs expected by this node int num_inputs() const { return static_cast(inputs_.size()); } - int num_outputs() const { return num_outputs_; } /// This node's predecessors in the exec plan const NodeVector& inputs() const { return inputs_; } + /// True if the plan has no output schema (is a sink) + bool is_sink() const { return !output_schema_; } + /// \brief Labels identifying the function of each input. const std::vector& input_labels() const { return input_labels_; } - /// This node's successors in the exec plan - const NodeVector& outputs() const { return outputs_; } + /// This node's successor in the exec plan + const ExecNode* output() const { return output_; } /// The datatypes for batches produced by this node const std::shared_ptr& output_schema() const { return output_schema_; } @@ -144,11 +144,11 @@ class ARROW_EXPORT ExecNode { const std::string& label() const { return label_; } void SetLabel(std::string label) { label_ = std::move(label); } - Status Validate() const; + virtual Status Validate() const; /// Upstream API: /// These functions are called by input nodes that want to inform this node - /// about an updated condition (a new input batch, an error, an impeding + /// about an updated condition (a new input batch or an impending /// end of stream). /// /// Implementation rules: @@ -159,17 +159,21 @@ class ARROW_EXPORT ExecNode { /// and StopProducing() /// Transfer input batch to ExecNode - virtual void InputReceived(ExecNode* input, ExecBatch batch) = 0; - - /// Signal error to ExecNode - virtual void ErrorReceived(ExecNode* input, Status error) = 0; + /// + /// A node will typically perform some kind of operation on the batch + /// and then call InputReceived on its outputs with the result. + /// + /// Other nodes may need to accumulate some number of inputs before any + /// output can be produced. These nodes will add the batch to some kind + /// of in-memory accumulation queue and return. + virtual Status InputReceived(ExecNode* input, ExecBatch batch) = 0; /// Mark the inputs finished after the given number of batches. /// /// This may be called before all inputs are received. This simply fixes /// the total number of incoming batches for an input, so that the ExecNode /// knows when it has received all input, regardless of order. - virtual void InputFinished(ExecNode* input, int total_batches) = 0; + virtual Status InputFinished(ExecNode* input, int total_batches) = 0; /// \brief Perform any needed initialization /// @@ -189,35 +193,26 @@ class ARROW_EXPORT ExecNode { /// - StartProducing() should not recurse into the inputs, as it is /// handled by ExecPlan::StartProducing() /// - PauseProducing(), ResumeProducing(), StopProducing() may be called - /// concurrently (but only after StartProducing() has returned successfully) + /// concurrently, potentially even before the call to StartProducing + /// has finished. /// - PauseProducing(), ResumeProducing(), StopProducing() may be called /// by the downstream nodes' InputReceived(), ErrorReceived(), InputFinished() /// methods - /// - StopProducing() should recurse into the inputs + /// + /// StopProducing may be called due to an error, by the user (e.g. cancel), or + /// because a node has all the data it needs (e.g. limit, top-k on sorted data). + /// This means the method may be called multiple times and we have the following + /// additional rules /// - StopProducing() must be idempotent + /// - StopProducing() must be forwarded to inputs (this is needed for the limit/top-k + /// case because we may not be stopping the entire plan) - // XXX What happens if StartProducing() calls an output's InputReceived() - // synchronously, and InputReceived() decides to call back into StopProducing() - // (or PauseProducing()) because it received enough data? - // // Right now, since synchronous calls happen in both directions (input to // output and then output to input), a node must be careful to be reentrant // against synchronous calls from its output, *and* also concurrent calls from // other threads. The most reliable solution is to update the internal state // first, and notify outputs only at the end. // - // Alternate rules: - // - StartProducing(), ResumeProducing() can call synchronously into - // its ouputs' consuming methods (InputReceived() etc.) - // - InputReceived(), ErrorReceived(), InputFinished() can call asynchronously - // into its inputs' PauseProducing(), StopProducing() - // - // Alternate API: - // - InputReceived(), ErrorReceived(), InputFinished() return a ProductionHint - // enum: either None (default), PauseProducing, ResumeProducing, StopProducing - // - A method allows passing a ProductionHint asynchronously from an output node - // (replacing PauseProducing(), ResumeProducing(), StopProducing()) - // Concurrent calls to PauseProducing and ResumeProducing can be hard to sequence // as they may travel at different speeds through the plan. // @@ -228,18 +223,10 @@ class ARROW_EXPORT ExecNode { // To resolve this a counter is sent for all calls to pause/resume. Only the call with // the highest counter value is valid. So if a call to PauseProducing(5) comes after // a call to ResumeProducing(6) then the source should continue producing. - // - // If a node has multiple outputs it should emit a new counter value to its inputs - // whenever any of its outputs changes which means the counters sent to inputs may be - // larger than the counters received on its outputs. - // - // A node with multiple outputs will also need to ensure it is applying backpressure if - // any of its outputs is asking to pause /// \brief Start producing /// - /// This must only be called once. If this fails, then other lifecycle - /// methods must not be called. + /// This must only be called once. /// /// This is typically called automatically by ExecPlan::StartProducing(). virtual Status StartProducing() = 0; @@ -252,7 +239,7 @@ class ARROW_EXPORT ExecNode { /// This call is a hint that an output node is currently not willing /// to receive data. /// - /// This may be called any number of times after StartProducing() succeeds. + /// This may be called any number of times. /// However, the node is still free to produce data (which may be difficult /// to prevent anyway if data is produced using multiple threads). virtual void PauseProducing(ExecNode* output, int32_t counter) = 0; @@ -264,34 +251,39 @@ class ARROW_EXPORT ExecNode { /// /// This call is a hint that an output node is willing to receive data again. /// - /// This may be called any number of times after StartProducing() succeeds. + /// This may be called any number of times. virtual void ResumeProducing(ExecNode* output, int32_t counter) = 0; - /// \brief Stop producing definitively to a single output + /// \brief Stop producing new data /// - /// This call is a hint that an output node has completed and is not willing - /// to receive any further data. - virtual void StopProducing(ExecNode* output) = 0; - - /// \brief Stop producing definitively to all outputs - virtual void StopProducing() = 0; - - /// \brief A future which will be marked finished when this node has stopped producing. - virtual Future<> finished() { return finished_; } + /// If this node is a source then the source should stop generating data + /// as quickly as possible. If this node is not a source then there is typically + /// nothing that needs to be done although a node may choose to start ignoring incoming + /// data. + /// + /// This method will be called when an error occurs in the plan + /// This method may also be called by the user if they wish to end a plan early + /// Finally, this method may be called if a node determines it no longer needs any more + /// input (for example, a limit node). + /// + /// This method may be called multiple times. + /// + /// This is not a pause. There will be no way to start the source again after this has + /// been called. + Status StopProducing(); std::string ToString(int indent = 0) const; protected: ExecNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - std::shared_ptr output_schema, int num_outputs); + std::shared_ptr output_schema); - // A helper method to send an error status to all outputs. - // Returns true if the status was an error. - bool ErrorIfNotOk(Status status); + virtual Status StopProducingImpl() = 0; /// Provide extra info to include in the string representation. virtual std::string ToStringExtra(int indent = 0) const; + std::atomic stopped_; ExecPlan* plan_; std::string label_; @@ -299,11 +291,7 @@ class ARROW_EXPORT ExecNode { std::vector input_labels_; std::shared_ptr output_schema_; - int num_outputs_; - NodeVector outputs_; - - // Future to sync finished - Future<> finished_ = Future<>::Make(); + ExecNode* output_ = NULLPTR; }; /// \brief An extensible registry for factories of ExecNodes diff --git a/cpp/src/arrow/compute/exec/filter_node.cc b/cpp/src/arrow/compute/exec/filter_node.cc index b44b25d29fe..b6877f106dc 100644 --- a/cpp/src/arrow/compute/exec/filter_node.cc +++ b/cpp/src/arrow/compute/exec/filter_node.cc @@ -68,33 +68,33 @@ class FilterNode : public MapNode { const char* kind_name() const override { return "FilterNode"; } - Result DoFilter(const ExecBatch& target) { + Result ProcessBatch(ExecBatch batch) override { ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, - SimplifyWithGuarantee(filter_, target.guarantee)); + SimplifyWithGuarantee(filter_, batch.guarantee)); util::tracing::Span span; START_COMPUTE_SPAN(span, "Filter", {{"filter.expression", ToStringExtra()}, {"filter.expression.simplified", simplified_filter.ToString()}, - {"filter.length", target.length}}); + {"filter.length", batch.length}}); ARROW_ASSIGN_OR_RAISE( - Datum mask, ExecuteScalarExpression(simplified_filter, target, + Datum mask, ExecuteScalarExpression(simplified_filter, batch, plan()->query_context()->exec_context())); if (mask.is_scalar()) { const auto& mask_scalar = mask.scalar_as(); if (mask_scalar.is_valid && mask_scalar.value) { - return target; + return batch; } - return target.Slice(0, 0); + return batch.Slice(0, 0); } // if the values are all scalar then the mask must also be - DCHECK(!std::all_of(target.values.begin(), target.values.end(), + DCHECK(!std::all_of(batch.values.begin(), batch.values.end(), [](const Datum& value) { return value.is_scalar(); })); - auto values = target.values; + auto values = batch.values; for (auto& value : values) { if (value.is_scalar()) continue; ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults())); @@ -102,15 +102,6 @@ class FilterNode : public MapNode { return ExecBatch::Make(std::move(values)); } - void InputReceived(ExecNode* input, ExecBatch batch) override { - DCHECK_EQ(input, inputs_[0]); - auto func = [this](ExecBatch batch) { - auto result = DoFilter(std::move(batch)); - return result; - }; - this->SubmitTask(std::move(func), std::move(batch)); - } - protected: std::string ToStringExtra(int indent = 0) const override { return "filter=" + filter_.ToString(); diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index ffd93591e65..a2d47ddf4e3 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -181,9 +181,9 @@ class HashJoinBasicImpl : public HashJoinImpl { } } - void ProbeBatch_OutputOne(int64_t batch_size_next, ExecBatch* opt_left_key, - ExecBatch* opt_left_payload, ExecBatch* opt_right_key, - ExecBatch* opt_right_payload) { + Status ProbeBatch_OutputOne(int64_t batch_size_next, ExecBatch* opt_left_key, + ExecBatch* opt_left_payload, ExecBatch* opt_right_key, + ExecBatch* opt_right_payload) { ExecBatch result({}, batch_size_next); int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT); int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT); @@ -232,11 +232,12 @@ class HashJoinBasicImpl : public HashJoinImpl { : opt_right_payload->values[from_payload.get(icol)]; } - output_batch_callback_(0, std::move(result)); + ARROW_RETURN_NOT_OK(output_batch_callback_(0, std::move(result))); // Update the counter of produced batches // num_batches_produced_++; + return Status::OK(); } Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state, @@ -405,12 +406,10 @@ class HashJoinBasicImpl : public HashJoinImpl { hash_table_payloads_.Decode(batch_size_next, opt_right_ids)); } - ProbeBatch_OutputOne(batch_size_next, has_left ? &left_key : nullptr, - has_left_payload ? &left_payload : nullptr, - has_right ? &right_key : nullptr, - has_right_payload ? &right_payload : nullptr); - - return Status::OK(); + return ProbeBatch_OutputOne(batch_size_next, has_left ? &left_key : nullptr, + has_left_payload ? &left_payload : nullptr, + has_right ? &right_key : nullptr, + has_right_payload ? &right_payload : nullptr); } Status ProbeBatch_OutputAll(size_t thread_index, const RowEncoder& exec_batch_keys, @@ -684,8 +683,7 @@ class HashJoinBasicImpl : public HashJoinImpl { return Status::Cancelled("Hash join cancelled"); } END_SPAN(span_); - finished_callback_(num_batches_produced_.load()); - return Status::OK(); + return finished_callback_(num_batches_produced_.load()); } Status ScanHashTable(size_t thread_index) { diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index bc053b2f1b6..52c12983d00 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -39,9 +39,9 @@ using arrow::util::AccumulationQueue; class HashJoinImpl { public: - using OutputBatchCallback = std::function; + using OutputBatchCallback = std::function; using BuildFinishedCallback = std::function; - using FinishedCallback = std::function; + using FinishedCallback = std::function; using RegisterTaskGroupCallback = std::function, std::function)>; using StartTaskGroupCallback = std::function; diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 998497ede49..8d4c635a783 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -685,8 +685,7 @@ class HashJoinNode : public ExecNode, public TracedNode { std::unique_ptr schema_mgr, Expression filter, std::unique_ptr impl) : ExecNode(plan, inputs, {"left", "right"}, - /*output_schema=*/std::move(output_schema), - /*num_outputs=*/1), + /*output_schema=*/std::move(output_schema)), join_type_(join_options.join_type), key_cmp_(join_options.key_cmp), filter_(std::move(filter)), @@ -872,58 +871,45 @@ class HashJoinNode : public ExecNode, public TracedNode { return Status::OK(); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { auto scope = TraceInputReceived(batch); ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); if (complete_.load()) { - return; + return Status::OK(); } size_t thread_index = plan_->query_context()->GetThreadIndex(); int side = (input == inputs_[0]) ? 0 : 1; - Status status = side == 0 ? OnProbeSideBatch(thread_index, std::move(batch)) - : OnBuildSideBatch(thread_index, std::move(batch)); - - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - return; + if (side == 0) { + ARROW_RETURN_NOT_OK(OnProbeSideBatch(thread_index, std::move(batch))); + } else { + ARROW_RETURN_NOT_OK(OnBuildSideBatch(thread_index, std::move(batch))); } if (batch_count_[side].Increment()) { - status = side == 0 ? OnProbeSideFinished(thread_index) - : OnBuildSideFinished(thread_index); - - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - return; + if (side == 0) { + return OnProbeSideFinished(thread_index); + } else { + return OnBuildSideFinished(thread_index); } } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - DCHECK_EQ(input, inputs_[0]); - StopProducing(); - outputs_[0]->ErrorReceived(this, std::move(error)); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); size_t thread_index = plan_->query_context()->GetThreadIndex(); int side = (input == inputs_[0]) ? 0 : 1; if (batch_count_[side].SetTotal(total_batches)) { - Status status = side == 0 ? OnProbeSideFinished(thread_index) - : OnBuildSideFinished(thread_index); - - if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - return; + if (side == 0) { + return OnProbeSideFinished(thread_index); + } else { + return OnBuildSideFinished(thread_index); } } + return Status::OK(); } Status Init() override { @@ -962,9 +948,9 @@ class HashJoinNode : public ExecNode, public TracedNode { [ctx](int task_group_id, int64_t num_tasks) { return ctx->StartTaskGroup(task_group_id, num_tasks); }, - [this](int64_t, ExecBatch batch) { this->OutputBatchCallback(batch); }, + [this](int64_t, ExecBatch batch) { return this->OutputBatchCallback(batch); }, [this](int64_t total_num_batches) { - this->FinishedCallback(total_num_batches); + return this->FinishedCallback(total_num_batches); })); task_group_probe_ = ctx->RegisterTaskGroup( @@ -994,18 +980,12 @@ class HashJoinNode : public ExecNode, public TracedNode { // TODO(ARROW-16246) } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - for (auto&& input : inputs_) { - input->StopProducing(this); - } - } - - void StopProducing() override { + Status StopProducingImpl() override { bool expected = false; if (complete_.compare_exchange_strong(expected, true)) { - impl_->Abort([this]() { finished_.MarkFinished(); }); + impl_->Abort([]() {}); } + return Status::OK(); } protected: @@ -1014,16 +994,16 @@ class HashJoinNode : public ExecNode, public TracedNode { } private: - void OutputBatchCallback(ExecBatch batch) { - outputs_[0]->InputReceived(this, std::move(batch)); + Status OutputBatchCallback(ExecBatch batch) { + return output_->InputReceived(this, std::move(batch)); } - void FinishedCallback(int64_t total_num_batches) { + Status FinishedCallback(int64_t total_num_batches) { bool expected = false; if (complete_.compare_exchange_strong(expected, true)) { - outputs_[0]->InputFinished(this, static_cast(total_num_batches)); - finished_.MarkFinished(); + return output_->InputFinished(this, static_cast(total_num_batches)); } + return Status::OK(); } private: diff --git a/cpp/src/arrow/compute/exec/map_node.cc b/cpp/src/arrow/compute/exec/map_node.cc index a5e73fb6f33..2576cb6259a 100644 --- a/cpp/src/arrow/compute/exec/map_node.cc +++ b/cpp/src/arrow/compute/exec/map_node.cc @@ -36,20 +36,16 @@ namespace compute { MapNode::MapNode(ExecPlan* plan, std::vector inputs, std::shared_ptr output_schema) : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"}, - std::move(output_schema), - /*num_outputs=*/1) {} + std::move(output_schema)) {} -void MapNode::ErrorReceived(ExecNode* input, Status error) { +Status MapNode::InputFinished(ExecNode* input, int total_batches) { DCHECK_EQ(input, inputs_[0]); - outputs_[0]->ErrorReceived(this, std::move(error)); -} - -void MapNode::InputFinished(ExecNode* input, int total_batches) { - DCHECK_EQ(input, inputs_[0]); - outputs_[0]->InputFinished(this, total_batches); + EVENT_ON_CURRENT_SPAN("InputFinished", {{"batches.length", total_batches}}); + ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_batches)); if (input_counter_.SetTotal(total_batches)) { this->Finish(); } + return Status::OK(); } Status MapNode::StartProducing() { @@ -65,54 +61,22 @@ void MapNode::ResumeProducing(ExecNode* output, int32_t counter) { inputs_[0]->ResumeProducing(this, counter); } -void MapNode::StopProducing(ExecNode* output) { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); -} +Status MapNode::StopProducingImpl() { return Status::OK(); } -void MapNode::StopProducing() { - if (input_counter_.Cancel()) { - this->Finish(); - } - inputs_[0]->StopProducing(this); -} - -void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, - ExecBatch batch) { +Status MapNode::InputReceived(ExecNode* input, ExecBatch batch) { auto scope = TraceInputReceived(batch); - Status status; - // This will be true if the node is stopped early due to an error or manual - // cancellation - if (input_counter_.Completed()) { - return; - } - auto task = [this, map_fn, batch]() { - auto guarantee = batch.guarantee; - auto output_batch = map_fn(std::move(batch)); - if (ErrorIfNotOk(output_batch.status())) { - return output_batch.status(); - } - output_batch->guarantee = guarantee; - outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe()); - return Status::OK(); - }; - - status = task(); - if (!status.ok()) { - if (input_counter_.Cancel()) { - this->Finish(status); - } - inputs_[0]->StopProducing(this); - return; - } + DCHECK_EQ(input, inputs_[0]); + compute::Expression guarantee = batch.guarantee; + ARROW_ASSIGN_OR_RAISE(auto output_batch, ProcessBatch(std::move(batch))); + output_batch.guarantee = guarantee; + ARROW_RETURN_NOT_OK(output_->InputReceived(this, std::move(output_batch))); if (input_counter_.Increment()) { this->Finish(); } + return Status::OK(); } -void MapNode::Finish(Status finish_st /*= Status::OK()*/) { - this->finished_.MarkFinished(finish_st); -} +void MapNode::Finish() {} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/map_node.h b/cpp/src/arrow/compute/exec/map_node.h index 9211afc0060..2e62030d64f 100644 --- a/cpp/src/arrow/compute/exec/map_node.h +++ b/cpp/src/arrow/compute/exec/map_node.h @@ -15,14 +15,6 @@ // specific language governing permissions and limitations // under the License. -/// \brief MapNode is an ExecNode type class which process a task like filter/project -/// (See SubmitTask method) to each given ExecBatch object, which have one input, one -/// output, and are pure functions on the input -/// -/// A simple parallel runner is created with a "map_fn" which is just a function that -/// takes a batch in and returns a batch. This simple parallel runner also needs an -/// executor (use simple synchronous runner if there is no executor) - #pragma once #include @@ -42,14 +34,19 @@ namespace arrow { namespace compute { +/// A utility base class for simple exec nodes with one input +/// +/// Pause/Resume Producing are forwarded appropriately +/// There is nothing to do in StopProducingImpl +/// +/// An AtomicCounter is used to keep track of when all data has arrived. When it +/// has the Finish() method will be invoked class ARROW_EXPORT MapNode : public ExecNode, public TracedNode { public: MapNode(ExecPlan* plan, std::vector inputs, std::shared_ptr output_schema); - void ErrorReceived(ExecNode* input, Status error) override; - - void InputFinished(ExecNode* input, int total_batches) override; + Status InputFinished(ExecNode* input, int total_batches) override; Status StartProducing() override; @@ -57,14 +54,21 @@ class ARROW_EXPORT MapNode : public ExecNode, public TracedNode { void ResumeProducing(ExecNode* output, int32_t counter) override; - void StopProducing(ExecNode* output) override; - - void StopProducing() override; + Status InputReceived(ExecNode* input, ExecBatch batch) override; protected: - void SubmitTask(std::function(ExecBatch)> map_fn, ExecBatch batch); - - virtual void Finish(Status finish_st = Status::OK()); + Status StopProducingImpl() override; + + /// Transform a batch + /// + /// The output batch will have the same guarantee as the input batch + /// If this was the last batch this call may trigger Finish() + virtual Result ProcessBatch(ExecBatch batch) = 0; + + /// Function called after all data has been received + /// + /// By default this does nothing. Override this to provide a custom implementation. + virtual void Finish(); protected: // Counter for the number of batches received diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 0ef75cbedcf..0d63600101e 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -321,6 +321,9 @@ class ARROW_EXPORT SinkNodeConsumer { /// \brief Signal to the consumer that the last batch has been delivered /// /// The returned future should only finish when all outstanding tasks have completed + /// + /// If the plan is ended early or aborts due to an error then this will not be + /// called. virtual Future<> Finish() = 0; }; diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 4f68cbab205..b0c0237e936 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -38,6 +38,7 @@ #include "arrow/util/thread_pool.h" #include "arrow/util/vector.h" +using testing::Contains; using testing::ElementsAre; using testing::ElementsAreArray; using testing::HasSubstr; @@ -55,58 +56,33 @@ TEST(ExecPlanConstruction, Empty) { } TEST(ExecPlanConstruction, SingleNode) { + // Single node that is both source and sink ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/0); + auto node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*is_sink=*/true); ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ElementsAre(node)); - ASSERT_THAT(plan->sinks(), ElementsAre(node)); + ASSERT_THAT(plan->nodes(), ElementsAre(node)); + // Single source node that is not supposed to be a sink (invalid) ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/1); + node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}); // Output not bound ASSERT_THAT(plan->Validate(), Raises(StatusCode::Invalid)); } TEST(ExecPlanConstruction, SourceSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*inputs=*/{}, /*num_outputs=*/1); - auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{source}, /*num_outputs=*/0); + auto source = MakeDummyNode(plan.get(), "source", /*inputs=*/{}); + auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{source}, /*is_sink=*/true); ASSERT_OK(plan->Validate()); - EXPECT_THAT(plan->sources(), ElementsAre(source)); - EXPECT_THAT(plan->sinks(), ElementsAre(sink)); -} - -TEST(ExecPlanConstruction, MultipleNode) { - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - - auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2); - - auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1); - - auto process1 = - MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2); - - auto process2 = MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1, source2}, - /*num_outputs=*/1); - - auto process3 = - MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, process2, process1}, - /*num_outputs=*/1); - - auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0); - - ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ElementsAre(source1, source2)); - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); + EXPECT_THAT(plan->nodes(), ElementsAre(source, sink)); } TEST(ExecPlanConstruction, AutoLabel) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "", /*inputs=*/{}, /*num_outputs=*/2); - auto source2 = - MakeDummyNode(plan.get(), "some_label", /*inputs=*/{}, /*num_outputs=*/1); - auto source3 = MakeDummyNode(plan.get(), "", /*inputs=*/{}, /*num_outputs=*/2); + auto source1 = MakeDummyNode(plan.get(), "", /*inputs=*/{}); + auto source2 = MakeDummyNode(plan.get(), "some_label", /*inputs=*/{}); + auto source3 = MakeDummyNode(plan.get(), "", /*inputs=*/{}); ASSERT_EQ("0", source1->label()); ASSERT_EQ("some_label", source2->label()); @@ -133,32 +109,32 @@ TEST(ExecPlan, DummyStartProducing) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2, + auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*is_sink=*/false, t.start_producing_func(), t.stop_producing_func()); - auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1, + auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*is_sink=*/false, t.start_producing_func(), t.stop_producing_func()); auto process1 = - MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2, + MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*is_sink=*/false, t.start_producing_func(), t.stop_producing_func()); auto process2 = MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + /*is_sink=*/false, t.start_producing_func(), t.stop_producing_func()); auto process3 = - MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, source1, process2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + MakeDummyNode(plan.get(), "process3", /*inputs=*/{process2}, + /*is_sink=*/false, t.start_producing_func(), t.stop_producing_func()); - MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0, + MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*is_sink=*/true, t.start_producing_func(), t.stop_producing_func()); ASSERT_OK(plan->Validate()); ASSERT_EQ(t.started.size(), 0); ASSERT_EQ(t.stopped.size(), 0); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); // Note that any correct reverse topological order may do ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1", "source2", "source1")); @@ -169,35 +145,35 @@ TEST(ExecPlan, DummyStartProducing) { ASSERT_THAT(t.stopped, ElementsAre("source1", "source2", "process1", "process2", "process3", "sink")); - ASSERT_THAT(plan->StartProducing(), - Raises(StatusCode::Invalid, HasSubstr("restarted"))); + plan->StartProducing(); + ASSERT_THAT(plan->finished(), Finishes(Raises(StatusCode::Invalid, + HasSubstr("plan had already finished")))); } TEST(ExecPlan, DummyStartProducingError) { StartStopTracker t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode( - plan.get(), "source1", /*num_inputs=*/{}, /*num_outputs=*/2, - t.start_producing_func(Status::NotImplemented("zzz")), t.stop_producing_func()); + auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*is_sink=*/false, + t.start_producing_func(Status::NotImplemented("zzz")), + t.stop_producing_func()); - auto source2 = - MakeDummyNode(plan.get(), "source2", /*num_inputs=*/{}, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); + auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*is_sink=*/false, + t.start_producing_func(), t.stop_producing_func()); auto process1 = MakeDummyNode( - plan.get(), "process1", /*num_inputs=*/{source1}, /*num_outputs=*/2, + plan.get(), "process1", /*inputs=*/{source1}, /*is_sink=*/false, t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func()); auto process2 = - MakeDummyNode(plan.get(), "process2", /*num_inputs=*/{process1, source2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2}, + /*is_sink=*/false, t.start_producing_func(), t.stop_producing_func()); auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/{process1, source1, process2}, - /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + MakeDummyNode(plan.get(), "process3", /*inputs=*/{process2}, + /*is_sink=*/false, t.start_producing_func(), t.stop_producing_func()); - MakeDummyNode(plan.get(), "sink", /*num_inputs=*/{process3}, /*num_outputs=*/0, + MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*is_sink=*/true, t.start_producing_func(), t.stop_producing_func()); ASSERT_OK(plan->Validate()); @@ -205,10 +181,12 @@ TEST(ExecPlan, DummyStartProducingError) { ASSERT_EQ(t.stopped.size(), 0); // `process1` raises IOError - ASSERT_THAT(plan->StartProducing(), Raises(StatusCode::IOError)); + plan->StartProducing(); + ASSERT_THAT(plan->finished(), Finishes(Raises(StatusCode::IOError))); ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1")); - // Nodes that started successfully were stopped in reverse order - ASSERT_THAT(t.stopped, ElementsAre("process2", "process3", "sink")); + // All nodes will be stopped when an abort happens + ASSERT_THAT(t.stopped, ElementsAre("process2", "process1", "source1", "source2", + "process3", "sink")); } TEST(ExecPlanExecution, SourceSink) { @@ -243,7 +221,7 @@ TEST(ExecPlanExecution, UseSinkAfterExecution) { {"sink", SinkNodeOptions{&sink_gen}}, }) .AddToPlan(plan.get())); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); } ASSERT_FINISHES_AND_RAISES(Invalid, sink_gen()); @@ -389,6 +367,16 @@ TEST(ExecPlanExecution, RecordBatchReaderSourceSinkError) { TestRecordBatchReaderSourceSinkError(ToRecordBatchReader); } +void CheckFinishesCancelledOrOk(const Future<>& fut) { + // There is a race condition with most tests that cancel plans. If the + // cancel call comes in too slowly then the plan might have already finished + // ok. + ASSERT_TRUE(fut.Wait(kDefaultAssertFinishesWaitSeconds)); + if (!fut.status().ok()) { + ASSERT_TRUE(fut.status().IsCancelled()); + } +} + TEST(ExecPlanExecution, SinkNodeBackpressure) { std::optional batch = ExecBatchFromJSON({int32(), boolean()}, @@ -414,7 +402,7 @@ TEST(ExecPlanExecution, SinkNodeBackpressure) { }) .AddToPlan(plan.get())); ASSERT_TRUE(backpressure_monitor); - ARROW_EXPECT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FALSE(backpressure_monitor->is_paused()); @@ -445,7 +433,7 @@ TEST(ExecPlanExecution, SinkNodeBackpressure) { // Cleanup batch_producer.producer().Push(IterationEnd>()); plan->StopProducing(); - ASSERT_FINISHES_OK(plan->finished()); + CheckFinishesCancelledOrOk(plan->finished()); } TEST(ExecPlan, ToString) { @@ -461,8 +449,8 @@ TEST(ExecPlan, ToString) { {"sink", SinkNodeOptions{&sink_gen}}, }) .AddToPlan(plan.get())); - EXPECT_EQ(plan->sources()[0]->ToString(), R"(:SourceNode{})"); - EXPECT_EQ(plan->sinks()[0]->ToString(), R"(:SinkNode{})"); + EXPECT_EQ(plan->nodes()[0]->ToString(), R"(:SourceNode{})"); + EXPECT_EQ(plan->nodes()[1]->ToString(), R"(:SinkNode{})"); EXPECT_EQ(plan->ToString(), R"(ExecPlan with 2 nodes: :SinkNode{} :SourceNode{} @@ -638,9 +626,7 @@ TEST(ExecPlanExecution, SourceConsumingSink) { basic_data.gen(parallel, slow)))); ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source}, ConsumingSinkNodeOptions(consumer))); - ASSERT_OK(plan->StartProducing()); - // Source should finish fairly quickly - ASSERT_FINISHES_OK(source->finished()); + plan->StartProducing(); SleepABit(); // Consumer isn't finished and so plan shouldn't have finished AssertNotFinished(plan->finished()); @@ -671,9 +657,7 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) { SourceNodeOptions(basic_data.schema, basic_data.gen(parallel, slow)))); ASSERT_OK(MakeExecNode("table_sink", plan.get(), {source}, options)); - ASSERT_OK(plan->StartProducing()); - // Source should finish fairly quickly - ASSERT_FINISHES_OK(source->finished()); + plan->StartProducing(); SleepABit(); ASSERT_OK_AND_ASSIGN(auto expected, TableFromExecBatches(basic_data.schema, basic_data.batches)); @@ -713,10 +697,11 @@ TEST(ExecPlanExecution, ConsumingSinkNames) { ConsumingSinkNodeOptions(consumer, names))); if (names.size() != 0 && names.size() != static_cast(basic_data.batches[0].num_values())) { - ASSERT_RAISES(Invalid, plan->StartProducing()); + plan->StartProducing(); + ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished()); } else { auto expected_names = names.size() == 0 ? basic_data.schema->field_names() : names; - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); ASSERT_EQ(expected_names, consumer->schema_->field_names()); } @@ -757,7 +742,8 @@ TEST(ExecPlanExecution, ConsumingSinkError) { Declaration plan = Declaration::Sequence( {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))}, {"consuming_sink", ConsumingSinkNodeOptions(consumer)}}); - // Since the source node is not parallel the entire plan is run during StartProducing + // Since the source node is not parallel the entire plan is run during + // StartProducing ASSERT_RAISES(Invalid, DeclarationToStatus(std::move(plan))); } } @@ -851,9 +837,9 @@ TEST(ExecPlanExecution, StressSourceGroupedSumStop) { .AddToPlan(plan.get())); ASSERT_OK(plan->Validate()); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); plan->StopProducing(); - ASSERT_FINISHES_OK(plan->finished()); + CheckFinishesCancelledOrOk(plan->finished()); } } } @@ -867,27 +853,43 @@ TEST(ExecPlanExecution, StressSourceSinkStopped) { int num_batches = (slow && !parallel) ? 30 : 300; - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); AsyncGenerator> sink_gen; auto random_data = MakeRandomBatches( schema({field("a", int32()), field("b", boolean())}), num_batches); - ASSERT_OK(Declaration::Sequence( - { - {"source", SourceNodeOptions{random_data.schema, - random_data.gen(parallel, slow)}}, - {"sink", SinkNodeOptions{&sink_gen}}, - }) - .AddToPlan(plan.get())); - - ASSERT_OK(plan->Validate()); - ASSERT_OK(plan->StartProducing()); - - EXPECT_THAT(sink_gen(), Finishes(ResultWith(Optional(random_data.batches[0])))); + Declaration decl = Declaration::Sequence({ + {"source", + SourceNodeOptions{random_data.schema, random_data.gen(parallel, slow)}}, + {"sink", SinkNodeOptions{&sink_gen}}, + }); + + std::shared_ptr plan; + Future> first_batch_fut = + ::arrow::internal::RunSynchronously>>( + [&](::arrow::internal::Executor* executor) + -> Future> { + ExecContext ctx = ExecContext(default_memory_pool(), executor); + ARROW_ASSIGN_OR_RAISE(plan, ExecPlan::Make(ctx)); + ARROW_RETURN_NOT_OK(decl.AddToPlan(plan.get())); + ARROW_RETURN_NOT_OK(plan->Validate()); + plan->StartProducing(); + return sink_gen(); + }, + parallel); + + if (parallel) { + ASSERT_FINISHES_OK_AND_ASSIGN(std::optional batch, first_batch_fut); + ASSERT_TRUE(batch.has_value()); + ASSERT_THAT(random_data.batches, Contains(*batch)); + } else { + EXPECT_THAT(first_batch_fut, + Finishes(ResultWith(Optional(random_data.batches[0])))); + } plan->StopProducing(); - ASSERT_THAT(plan->finished(), Finishes(Ok())); + Future<> finished = plan->finished(); + CheckFinishesCancelledOrOk(plan->finished()); } } } diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc index 239f2d38de1..e355d012e18 100644 --- a/cpp/src/arrow/compute/exec/project_node.cc +++ b/cpp/src/arrow/compute/exec/project_node.cc @@ -77,31 +77,22 @@ class ProjectNode : public MapNode, public TracedNode { const char* kind_name() const override { return "ProjectNode"; } - Result DoProject(const ExecBatch& target) { + Result ProcessBatch(ExecBatch batch) override { std::vector values{exprs_.size()}; for (size_t i = 0; i < exprs_.size(); ++i) { util::tracing::Span span; START_COMPUTE_SPAN(span, "Project", {{"project.type", exprs_[i].type()->ToString()}, - {"project.length", target.length}, + {"project.length", batch.length}, {"project.expression", exprs_[i].ToString()}}); ARROW_ASSIGN_OR_RAISE(Expression simplified_expr, - SimplifyWithGuarantee(exprs_[i], target.guarantee)); + SimplifyWithGuarantee(exprs_[i], batch.guarantee)); ARROW_ASSIGN_OR_RAISE( - values[i], ExecuteScalarExpression(simplified_expr, target, + values[i], ExecuteScalarExpression(simplified_expr, batch, plan()->query_context()->exec_context())); } - return ExecBatch{std::move(values), target.length}; - } - - void InputReceived(ExecNode* input, ExecBatch batch) override { - DCHECK_EQ(input, inputs_[0]); - auto func = [this](ExecBatch batch) { - auto result = DoProject(std::move(batch)); - return result; - }; - this->SubmitTask(std::move(func), std::move(batch)); + return ExecBatch{std::move(values), batch.length}; } protected: diff --git a/cpp/src/arrow/compute/exec/query_context.cc b/cpp/src/arrow/compute/exec/query_context.cc index f4664f1fe93..65a1a674181 100644 --- a/cpp/src/arrow/compute/exec/query_context.cc +++ b/cpp/src/arrow/compute/exec/query_context.cc @@ -60,28 +60,25 @@ Result> QueryContext::BeginExternalTask(std::string_view name) { return Future<>{}; } -Status QueryContext::ScheduleTask(std::function fn, std::string_view name) { +void QueryContext::ScheduleTask(std::function fn, std::string_view name) { ::arrow::internal::Executor* exec = executor(); // Adds a task which submits fn to the executor and tracks its progress. If we're // already stopping then the task is ignored and fn is not executed. async_scheduler_->AddSimpleTask([exec, fn]() { return exec->Submit(std::move(fn)); }, name); - return Status::OK(); } -Status QueryContext::ScheduleTask(std::function fn, - std::string_view name) { +void QueryContext::ScheduleTask(std::function fn, std::string_view name) { std::function indexed_fn = [this, fn]() { size_t thread_index = GetThreadIndex(); return fn(thread_index); }; - return ScheduleTask(std::move(indexed_fn), name); + ScheduleTask(std::move(indexed_fn), name); } -Status QueryContext::ScheduleIOTask(std::function fn, std::string_view name) { +void QueryContext::ScheduleIOTask(std::function fn, std::string_view name) { async_scheduler_->AddSimpleTask( [this, fn]() { return io_context_.executor()->Submit(std::move(fn)); }, name); - return Status::OK(); } int QueryContext::RegisterTaskGroup(std::function task, diff --git a/cpp/src/arrow/compute/exec/query_context.h b/cpp/src/arrow/compute/exec/query_context.h index 8d5379f8f28..10f151b3188 100644 --- a/cpp/src/arrow/compute/exec/query_context.h +++ b/cpp/src/arrow/compute/exec/query_context.h @@ -83,19 +83,19 @@ class ARROW_EXPORT QueryContext { /// /// \param fn The task to run. Takes no arguments and returns a Status. /// \param name A name to give the task for traceability and debugging - Status ScheduleTask(std::function fn, std::string_view name); + void ScheduleTask(std::function fn, std::string_view name); /// \brief Add a single function as a task to the query's task group /// on the compute threadpool. /// /// \param fn The task to run. Takes the thread index and returns a Status. /// \param name A name to give the task for traceability and debugging - Status ScheduleTask(std::function fn, std::string_view name); + void ScheduleTask(std::function fn, std::string_view name); /// \brief Add a single function as a task to the query's task group on /// the IO thread pool /// /// \param fn The task to run. Returns a status. /// \param name A name to give the task for traceability and debugging - Status ScheduleIOTask(std::function fn, std::string_view name); + void ScheduleIOTask(std::function fn, std::string_view name); // Register/Start TaskGroup is a way of performing a "Parallel For" pattern: // - The task function takes the thread index and the index of the task diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 62c226885e5..7481245ccf6 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -103,8 +103,7 @@ class SinkNode : public ExecNode, public TracedNode { AsyncGenerator>* generator, std::shared_ptr* schema, BackpressureOptions backpressure, BackpressureMonitor** backpressure_monitor_out) - : ExecNode(plan, std::move(inputs), {"collected"}, {}, - /*num_outputs=*/0), + : ExecNode(plan, std::move(inputs), {"collected"}, {}), backpressure_queue_(backpressure.resume_if_below, backpressure.pause_if_above), push_gen_(), producer_(push_gen_.producer()), @@ -160,11 +159,20 @@ class SinkNode : public ExecNode, public TracedNode { [[noreturn]] void PauseProducing(ExecNode* output, int32_t counter) override { NoOutputs(); } - [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); } - void StopProducing() override { - Finish(); - inputs_[0]->StopProducing(this); + Status StopProducingImpl() override { + // An AsyncGenerator must always be finished. So we go ahead and + // close the producer. However, for custom sink nodes, we don't want + // to bother ordering and sending output so we don't call Finish + producer_.Close(); + return Status::OK(); + } + + Status Validate() const override { + if (output_ != nullptr) { + return Status::Invalid("Sink node '", label(), "' has an output"); + } + return ExecNode::Validate(); } void RecordBackpressureBytesUsed(const ExecBatch& batch) { @@ -193,41 +201,32 @@ class SinkNode : public ExecNode, public TracedNode { } } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { auto scope = TraceInputReceived(batch); + DCHECK_EQ(input, inputs_[0]); RecordBackpressureBytesUsed(batch); bool did_push = producer_.Push(std::move(batch)); - if (!did_push) return; // producer_ was Closed already + if (!did_push) return Status::OK(); // producer_ was Closed already if (input_counter_.Increment()) { - Finish(); - } - } - - void ErrorReceived(ExecNode* input, Status error) override { - DCHECK_EQ(input, inputs_[0]); - - producer_.Push(std::move(error)); - - if (input_counter_.Cancel()) { - Finish(); + return Finish(); } - inputs_[0]->StopProducing(this); + return Status::OK(); } - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { if (input_counter_.SetTotal(total_batches)) { - Finish(); + return Finish(); } + return Status::OK(); } protected: - virtual void Finish() { - if (producer_.Close()) { - finished_.MarkFinished(); - } + virtual Status Finish() { + producer_.Close(); + return Status::OK(); } static Status ValidateOptions(const SinkNodeOptions& sink_options) { @@ -268,8 +267,7 @@ class ConsumingSinkNode : public ExecNode, ConsumingSinkNode(ExecPlan* plan, std::vector inputs, std::shared_ptr consumer, std::vector names) - : ExecNode(plan, std::move(inputs), {"to_consume"}, {}, - /*num_outputs=*/0), + : ExecNode(plan, std::move(inputs), {"to_consume"}, {}), consumer_(std::move(consumer)), names_(std::move(names)) {} @@ -320,63 +318,46 @@ class ConsumingSinkNode : public ExecNode, [[noreturn]] void PauseProducing(ExecNode* output, int32_t counter) override { NoOutputs(); } - [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); } void Pause() override { inputs_[0]->PauseProducing(this, ++backpressure_counter_); } void Resume() override { inputs_[0]->ResumeProducing(this, ++backpressure_counter_); } - void StopProducing() override { - if (input_counter_.Cancel()) { - Finish(Status::OK()); - } + Status StopProducingImpl() override { + // We do not call the consumer's finish method if ending early. This might leave us + // with half-written data files (in a dataset write for example) and that is ok. + return Status::OK(); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { auto scope = TraceInputReceived(batch); + DCHECK_EQ(input, inputs_[0]); // This can happen if an error was received and the source hasn't yet stopped. Since // we have already called consumer_->Finish we don't want to call consumer_->Consume if (input_counter_.Completed()) { - return; - } - - Status consumption_status = consumer_->Consume(std::move(batch)); - if (!consumption_status.ok()) { - if (input_counter_.Cancel()) { - Finish(std::move(consumption_status)); - } - inputs_[0]->StopProducing(this); - return; + return Status::OK(); } + ARROW_RETURN_NOT_OK(consumer_->Consume(std::move(batch))); if (input_counter_.Increment()) { - Finish(Status::OK()); + Finish(); } + return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override { - DCHECK_EQ(input, inputs_[0]); - - if (input_counter_.Cancel()) Finish(error); - - inputs_[0]->StopProducing(this); - } - - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { if (input_counter_.SetTotal(total_batches)) { - Finish(Status::OK()); + Finish(); } + return Status::OK(); } protected: - void Finish(const Status& finish_st) { - if (finish_st.ok()) { - plan_->query_context()->async_scheduler()->AddSimpleTask( - [this] { return consumer_->Finish(); }, "ConsumingSinkNode::Finish"sv); - } - finished_.MarkFinished(finish_st); + void Finish() { + plan_->query_context()->async_scheduler()->AddSimpleTask( + [this] { return consumer_->Finish(); }, "ConsumingSinkNode::Finish"sv); } AtomicCounter input_counter_; @@ -465,26 +446,20 @@ struct OrderBySinkNode final : public SinkNode { return ValidateCommonOrderOptions(options); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { auto scope = TraceInputReceived(batch); DCHECK_EQ(input, inputs_[0]); - auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(), - plan()->query_context()->memory_pool()); - if (ErrorIfNotOk(maybe_batch.status())) { - StopProducing(); - if (input_counter_.Cancel()) { - finished_.MarkFinished(maybe_batch.status()); - } - return; - } - auto record_batch = maybe_batch.MoveValueUnsafe(); + ARROW_ASSIGN_OR_RAISE(auto record_batch, + batch.ToRecordBatch(inputs_[0]->output_schema(), + plan()->query_context()->memory_pool())); impl_->InputReceived(std::move(record_batch)); if (input_counter_.Increment()) { - Finish(); + return Finish(); } + return Status::OK(); } protected: @@ -502,12 +477,10 @@ struct OrderBySinkNode final : public SinkNode { return Status::OK(); } - void Finish() override { - Status st = DoFinish(); - if (ErrorIfNotOk(st)) { - producer_.Push(std::move(st)); - } - SinkNode::Finish(); + Status Finish() override { + util::tracing::Span span; + ARROW_RETURN_NOT_OK(DoFinish()); + return SinkNode::Finish(); } protected: diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index af43b63d0cf..fbd84e3ca4e 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -53,8 +53,7 @@ namespace { struct SourceNode : ExecNode, public TracedNode { SourceNode(ExecPlan* plan, std::shared_ptr output_schema, AsyncGenerator> generator) - : ExecNode(plan, {}, {}, std::move(output_schema), - /*num_outputs=*/1), + : ExecNode(plan, {}, {}, std::move(output_schema)), generator_(std::move(generator)) {} static Result Make(ExecPlan* plan, std::vector inputs, @@ -70,9 +69,40 @@ struct SourceNode : ExecNode, public TracedNode { [[noreturn]] static void NoInputs() { Unreachable("no inputs; this should never be called"); } - [[noreturn]] void InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } - [[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); } - [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); } + [[noreturn]] Status InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } + [[noreturn]] Status InputFinished(ExecNode*, int) override { NoInputs(); } + + void SliceAndDeliverMorsel(const ExecBatch& morsel) { + bool use_legacy_batching = plan_->query_context()->options().use_legacy_batching; + int64_t morsel_length = static_cast(morsel.length); + if (use_legacy_batching || morsel_length == 0) { + // For various reasons (e.g. ARROW-13982) we pass empty batches + // through + batch_count_++; + } else { + int num_batches = + static_cast(bit_util::CeilDiv(morsel_length, ExecPlan::kMaxBatchSize)); + batch_count_ += num_batches; + } + plan_->query_context()->ScheduleTask( + [=]() { + int64_t offset = 0; + do { + int64_t batch_size = + std::min(morsel_length - offset, ExecPlan::kMaxBatchSize); + // In order for the legacy batching model to work we must + // not slice batches from the source + if (use_legacy_batching) { + batch_size = morsel_length; + } + ExecBatch batch = morsel.Slice(offset, batch_size); + offset += batch_size; + ARROW_RETURN_NOT_OK(output_->InputReceived(this, std::move(batch))); + } while (offset < morsel.length); + return Status::OK(); + }, + "SourceNode::ProcessMorsel"); + } Status StartProducing() override { NoteStartProducing(ToStringExtra()); @@ -89,95 +119,59 @@ struct SourceNode : ExecNode, public TracedNode { } CallbackOptions options; - auto executor = plan()->query_context()->executor(); - if (executor) { - // These options will transfer execution to the desired Executor if necessary. - // This can happen for in-memory scans where batches didn't require - // any CPU work to decode. Otherwise, parsing etc should have already - // been placed us on the desired Executor and no queues will be pushed to. - options.executor = executor; - options.should_schedule = ShouldSchedule::IfDifferentExecutor; - } + // These options will transfer execution to the desired Executor if necessary. + // This can happen for in-memory scans where batches didn't require + // any CPU work to decode. Otherwise, parsing etc should have already + // been placed us on the desired Executor and no queues will be pushed to. + options.executor = plan()->query_context()->executor(); + options.should_schedule = ShouldSchedule::IfDifferentExecutor; ARROW_ASSIGN_OR_RAISE(Future<> scan_task, plan_->query_context()->BeginExternalTask( "SourceNode::DatasetScan")); if (!scan_task.is_valid()) { - finished_.MarkFinished(); // Plan has already been aborted, no need to start scanning return Status::OK(); } auto fut = Loop([this, options] { - std::unique_lock lock(mutex_); - if (stop_requested_) { - return Future>::MakeFinished(Break(batch_count_)); - } - lock.unlock(); - util::tracing::Span fetch_batch_span; - auto fetch_batch_scope = - START_SCOPED_SPAN(fetch_batch_span, "SourceNode::ReadBatch"); - return generator_().Then( - [this, fetch_batch_span = std::move(fetch_batch_span)]( - const std::optional& maybe_morsel) mutable - -> Future> { - fetch_batch_span.reset(); - std::unique_lock lock(mutex_); - if (IsIterationEnd(maybe_morsel) || stop_requested_) { - return Break(batch_count_); - } - lock.unlock(); - bool use_legacy_batching = - plan_->query_context()->options().use_legacy_batching; - ExecBatch morsel = std::move(*maybe_morsel); - int64_t morsel_length = static_cast(morsel.length); - if (use_legacy_batching || morsel_length == 0) { - // For various reasons (e.g. ARROW-13982) we pass empty batches - // through - batch_count_++; - } else { - int num_batches = static_cast( - bit_util::CeilDiv(morsel_length, ExecPlan::kMaxBatchSize)); - batch_count_ += num_batches; - } - RETURN_NOT_OK(plan_->query_context()->ScheduleTask( - [this, morsel = std::move(morsel), morsel_length, - use_legacy_batching]() { - int64_t offset = 0; - do { - int64_t batch_size = std::min( - morsel_length - offset, ExecPlan::kMaxBatchSize); - // In order for the legacy batching model to work we must - // not slice batches from the source - if (use_legacy_batching) { - batch_size = morsel_length; - } - ExecBatch batch = morsel.Slice(offset, batch_size); - offset += batch_size; - outputs_[0]->InputReceived(this, std::move(batch)); - } while (offset < morsel.length); - return Status::OK(); - }, - "SourceNode::ProcessMorsel")); - lock.lock(); - if (!backpressure_future_.is_finished()) { - EVENT_ON_CURRENT_SPAN("SourceNode::BackpressureApplied"); - return backpressure_future_.Then( - []() -> ControlFlow { return Continue(); }); - } - return Future>::MakeFinished(Continue()); - }, - [this](const Status& error) -> ControlFlow { - outputs_[0]->ErrorReceived(this, error); - return Break(batch_count_); - }, - options); - }) - .Then( - [this, scan_task](int total_batches) mutable { - outputs_[0]->InputFinished(this, total_batches); - scan_task.MarkFinished(); - finished_.MarkFinished(); - }, - {}, options); - if (!executor && finished_.is_finished()) return finished_.status(); + std::unique_lock lock(mutex_); + if (stop_requested_) { + return Future>::MakeFinished(Break(batch_count_)); + } + lock.unlock(); + + util::tracing::Span fetch_batch_span; + auto fetch_batch_scope = + START_SCOPED_SPAN(fetch_batch_span, "SourceNode::ReadBatch"); + return generator_().Then( + [this]( + const std::optional& morsel_or_end) -> Future> { + std::unique_lock lock(mutex_); + if (IsIterationEnd(morsel_or_end) || stop_requested_) { + return Break(batch_count_); + } + lock.unlock(); + SliceAndDeliverMorsel(*morsel_or_end); + lock.lock(); + if (!backpressure_future_.is_finished()) { + EVENT_ON_CURRENT_SPAN("SourceNode::BackpressureApplied"); + return backpressure_future_.Then( + []() -> ControlFlow { return Continue(); }); + } + return Future>::MakeFinished(Continue()); + }, + [](const Status& err) -> Future> { return err; }, options); + }); + fut.AddCallback( + [this, scan_task](Result maybe_total_batches) mutable { + if (maybe_total_batches.ok()) { + plan_->query_context()->ScheduleTask( + [this, total_batches = *maybe_total_batches] { + return output_->InputFinished(this, total_batches); + }, + "SourceNode::InputFinished"); + } + scan_task.MarkFinished(maybe_total_batches.status()); + }, + options); return Status::OK(); } @@ -210,17 +204,10 @@ struct SourceNode : ExecNode, public TracedNode { to_finish.MarkFinished(); } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override { + Status StopProducingImpl() override { std::unique_lock lock(mutex_); stop_requested_ = true; - if (!started_) { - finished_.MarkFinished(); - } + return Status::OK(); } private: diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc b/cpp/src/arrow/compute/exec/swiss_join.cc index fee3c5f79db..d0c9dbe0977 100644 --- a/cpp/src/arrow/compute/exec/swiss_join.cc +++ b/cpp/src/arrow/compute/exec/swiss_join.cc @@ -1934,7 +1934,9 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), - [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + [&](ExecBatch batch) { + return output_batch_fn_(thread_id, std::move(batch)); + })); } } else { // We need to output matching pairs of rows from both sides of the join. @@ -1972,8 +1974,9 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, // RETURN_NOT_OK(materialize_[thread_id]->Append( keypayload_batch, num_matches_next, materialize_batch_ids, - materialize_key_ids, materialize_payload_ids, - [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + materialize_key_ids, materialize_payload_ids, [&](ExecBatch batch) { + return output_batch_fn_(thread_id, std::move(batch)); + })); } // For left-outer and full-outer joins output non-matches. @@ -1997,7 +2000,9 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), - [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + [&](ExecBatch batch) { + return output_batch_fn_(thread_id, std::move(batch)); + })); } } @@ -2014,7 +2019,7 @@ Status JoinProbeProcessor::OnFinished() { for (size_t i = 0; i < materialize_.size(); ++i) { JoinResultMaterialize& materialize = *materialize_[i]; RETURN_NOT_OK(materialize.Flush( - [&](ExecBatch batch) { output_batch_fn_(i, std::move(batch)); })); + [&](ExecBatch batch) { return output_batch_fn_(i, std::move(batch)); })); } return Status::OK(); @@ -2368,7 +2373,8 @@ class SwissJoin : public HashJoinImpl { Status status = local_states_[thread_id].materialize.AppendBuildOnly( num_output_rows, key_ids_buf.mutable_data(), payload_ids_buf.mutable_data(), [&](ExecBatch batch) { - output_batch_callback_(static_cast(thread_id), std::move(batch)); + return output_batch_callback_(static_cast(thread_id), + std::move(batch)); }); RETURN_NOT_OK(CancelIfNotOK(status)); if (!status.ok()) { @@ -2406,9 +2412,7 @@ class SwissJoin : public HashJoinImpl { num_produced_batches += materialize.num_produced_batches(); } - finished_callback_(num_produced_batches); - - return Status::OK(); + return finished_callback_(num_produced_batches); } Result KeyPayloadFromInput(int side, ExecBatch* input) { diff --git a/cpp/src/arrow/compute/exec/swiss_join.h b/cpp/src/arrow/compute/exec/swiss_join.h index bf3273c4e04..cbbd6d0ca2b 100644 --- a/cpp/src/arrow/compute/exec/swiss_join.h +++ b/cpp/src/arrow/compute/exec/swiss_join.h @@ -558,7 +558,7 @@ class JoinResultMaterialize { if (num_rows_appended < num_rows_to_append) { ExecBatch batch; ARROW_RETURN_NOT_OK(Flush(&batch)); - output_batch_fn(batch); + ARROW_RETURN_NOT_OK(output_batch_fn(batch)); num_rows_to_append -= num_rows_appended; offset += num_rows_appended; } else { @@ -613,7 +613,7 @@ class JoinResultMaterialize { if (num_rows_ > 0) { ExecBatch batch({}, num_rows_); ARROW_RETURN_NOT_OK(Flush(&batch)); - output_batch_fn(std::move(batch)); + ARROW_RETURN_NOT_OK(output_batch_fn(std::move(batch))); } return Status::OK(); } @@ -731,7 +731,7 @@ class JoinMatchIterator { // class JoinProbeProcessor { public: - using OutputBatchFn = std::function; + using OutputBatchFn = std::function; void Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table, std::vector materialize, diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index e250ef53d4f..4ac3fec55a3 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -59,25 +59,24 @@ namespace compute { namespace { struct DummyNode : ExecNode { - DummyNode(ExecPlan* plan, NodeVector inputs, int num_outputs, + DummyNode(ExecPlan* plan, NodeVector inputs, bool is_sink, StartProducingFunc start_producing, StopProducingFunc stop_producing) - : ExecNode(plan, std::move(inputs), {}, dummy_schema(), num_outputs), + : ExecNode(plan, std::move(inputs), {}, (is_sink) ? nullptr : dummy_schema()), start_producing_(std::move(start_producing)), stop_producing_(std::move(stop_producing)) { input_labels_.resize(inputs_.size()); for (size_t i = 0; i < input_labels_.size(); ++i) { input_labels_[i] = std::to_string(i); } - finished_.MarkFinished(); } const char* kind_name() const override { return "Dummy"; } - void InputReceived(ExecNode* input, ExecBatch batch) override {} + Status InputReceived(ExecNode* input, ExecBatch batch) override { return Status::OK(); } - void ErrorReceived(ExecNode* input, Status error) override {} - - void InputFinished(ExecNode* input, int total_batches) override {} + Status InputFinished(ExecNode* input, int total_batches) override { + return Status::OK(); + } Status StartProducing() override { if (start_producing_) { @@ -88,36 +87,24 @@ struct DummyNode : ExecNode { } void PauseProducing(ExecNode* output, int32_t counter) override { - ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + ASSERT_NE(output_, nullptr) << "Sink nodes should not experience backpressure"; AssertIsOutput(output); } void ResumeProducing(ExecNode* output, int32_t counter) override { - ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; - AssertIsOutput(output); - } - - void StopProducing(ExecNode* output) override { - EXPECT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + ASSERT_NE(output_, nullptr) << "Sink nodes should not experience backpressure"; AssertIsOutput(output); } - void StopProducing() override { - if (started_) { - for (const auto& input : inputs_) { - input->StopProducing(this); - } - if (stop_producing_) { - stop_producing_(this); - } + Status StopProducingImpl() override { + if (stop_producing_) { + stop_producing_(this); } + return Status::OK(); } private: - void AssertIsOutput(ExecNode* output) { - auto it = std::find(outputs_.begin(), outputs_.end(), output); - ASSERT_NE(it, outputs_.end()); - } + void AssertIsOutput(ExecNode* output) { ASSERT_EQ(output->output(), nullptr); } std::shared_ptr dummy_schema() const { return schema({field("dummy", null())}); @@ -132,10 +119,10 @@ struct DummyNode : ExecNode { } // namespace ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, - int num_outputs, StartProducingFunc start_producing, + bool is_sink, StartProducingFunc start_producing, StopProducingFunc stop_producing) { auto node = - plan->EmplaceNode(plan, std::move(inputs), num_outputs, + plan->EmplaceNode(plan, std::move(inputs), is_sink, std::move(start_producing), std::move(stop_producing)); if (!label.empty()) { node->SetLabel(std::move(label)); @@ -175,14 +162,14 @@ ExecBatch ExecBatchFromJSON(const std::vector& types, Future<> StartAndFinish(ExecPlan* plan) { RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); return plan->finished(); } Future> StartAndCollect( ExecPlan* plan, AsyncGenerator> gen) { RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); auto collected_fut = CollectAsyncGenerator(gen); diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 189e33a94b2..d10ed783eb9 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -42,7 +42,8 @@ using StopProducingFunc = std::function; // Make a dummy node that has no execution behaviour ARROW_TESTING_EXPORT ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, - int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); + bool is_sink = false, StartProducingFunc = {}, + StopProducingFunc = {}); ARROW_TESTING_EXPORT ExecBatch ExecBatchFromJSON(const std::vector& types, std::string_view json); diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc b/cpp/src/arrow/compute/exec/tpch_node.cc index dca1e59821b..cba370ab014 100644 --- a/cpp/src/arrow/compute/exec/tpch_node.cc +++ b/cpp/src/arrow/compute/exec/tpch_node.cc @@ -499,8 +499,8 @@ bool TpchPseudotext::GenerateSentence(int64_t& offset, random::pcg32_fast& rng, class TpchTableGenerator { public: - using OutputBatchCallback = std::function; - using FinishedCallback = std::function; + using OutputBatchCallback = std::function; + using FinishedCallback = std::function; using GenerateFn = std::function; using ScheduleCallback = std::function; using AbortCallback = std::function; @@ -2514,11 +2514,11 @@ class SupplierGenerator : public TpchTableGenerator { int64_t batches_to_generate = (rows_to_generate_ + batch_size_ - 1) / batch_size_; int64_t batches_outputted_before_this_one = batches_outputted_.fetch_add(1); bool is_last_batch = batches_outputted_before_this_one == (batches_to_generate - 1); - output_callback_(std::move(eb)); + ARROW_RETURN_NOT_OK(output_callback_(std::move(eb))); if (is_last_batch) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + ARROW_RETURN_NOT_OK(finished_callback_(batches_outputted_.load())); return Status::OK(); } return schedule_callback_( @@ -2722,14 +2722,14 @@ class PartGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + ARROW_RETURN_NOT_OK(finished_callback_(batches_outputted_.load())); return Status::OK(); } return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + ARROW_RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -2784,14 +2784,14 @@ class PartSuppGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + ARROW_RETURN_NOT_OK(finished_callback_(batches_outputted_.load())); return Status::OK(); } return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + ARROW_RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -2898,11 +2898,11 @@ class CustomerGenerator : public TpchTableGenerator { int64_t batches_to_generate = (rows_to_generate_ + batch_size_ - 1) / batch_size_; int64_t batches_generated_before_this_one = batches_outputted_.fetch_add(1); bool is_last_batch = batches_generated_before_this_one == (batches_to_generate - 1); - output_callback_(std::move(eb)); + ARROW_RETURN_NOT_OK(output_callback_(std::move(eb))); if (is_last_batch) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + ARROW_RETURN_NOT_OK(finished_callback_(batches_outputted_.load())); return Status::OK(); } return schedule_callback_( @@ -3103,14 +3103,14 @@ class OrdersGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + ARROW_RETURN_NOT_OK(finished_callback_(batches_outputted_.load())); return Status::OK(); } return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + ARROW_RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -3165,7 +3165,7 @@ class LineitemGenerator : public TpchTableGenerator { if (batches_generated == batches_outputted_.load()) { bool expected = false; if (done_.compare_exchange_strong(expected, true)) - finished_callback_(batches_outputted_.load()); + ARROW_RETURN_NOT_OK(finished_callback_(batches_outputted_.load())); return Status::OK(); } // We may have generated but not outputted all of the batches. @@ -3173,7 +3173,7 @@ class LineitemGenerator : public TpchTableGenerator { [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); } ExecBatch batch = std::move(*maybe_batch); - output_callback_(std::move(batch)); + ARROW_RETURN_NOT_OK(output_callback_(std::move(batch))); batches_outputted_++; return schedule_callback_( [this](size_t thread_index) { return this->ProduceCallback(thread_index); }); @@ -3228,9 +3228,8 @@ class NationGenerator : public TpchTableGenerator { std::vector result; for (const int& col : column_indices_) result.push_back(fields[col]); ARROW_ASSIGN_OR_RAISE(ExecBatch batch, ExecBatch::Make(std::move(result))); - output_callback(std::move(batch)); - finished_callback(static_cast(1)); - return Status::OK(); + ARROW_RETURN_NOT_OK(output_callback(std::move(batch))); + return finished_callback(static_cast(1)); } std::shared_ptr schema() const override { return schema_; } @@ -3317,8 +3316,8 @@ class RegionGenerator : public TpchTableGenerator { std::vector result; for (const int& col : column_indices_) result.push_back(fields[col]); ARROW_ASSIGN_OR_RAISE(ExecBatch batch, ExecBatch::Make(std::move(result))); - output_callback(std::move(batch)); - finished_callback(static_cast(1)); + ARROW_RETURN_NOT_OK(output_callback(std::move(batch))); + ARROW_RETURN_NOT_OK(finished_callback(static_cast(1))); return Status::OK(); } @@ -3364,7 +3363,7 @@ class TpchNode : public ExecNode { public: TpchNode(ExecPlan* plan, const char* name, std::unique_ptr generator) - : ExecNode(plan, {}, {}, generator->schema(), /*num_outputs=*/1), + : ExecNode(plan, {}, {}, generator->schema()), name_(name), generator_(std::move(generator)) {} @@ -3374,24 +3373,18 @@ class TpchNode : public ExecNode { Unreachable("TPC-H node should never have any inputs"); } - [[noreturn]] void InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } - - [[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); } + [[noreturn]] Status InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } - [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); } + [[noreturn]] Status InputFinished(ExecNode*, int) override { NoInputs(); } Status StartProducing() override { - num_running_++; RETURN_NOT_OK(generator_->StartProducing( plan_->query_context()->max_concurrency(), - [this](ExecBatch batch) { this->OutputBatchCallback(std::move(batch)); }, - [this](int64_t num_batches) { this->FinishedCallback(num_batches); }, + [this](ExecBatch batch) { return this->OutputBatchCallback(std::move(batch)); }, + [this](int64_t num_batches) { return this->FinishedCallback(num_batches); }, [this](std::function func) -> Status { return this->ScheduleTaskCallback(std::move(func)); })); - if (--num_running_ == 0) { - finished_.MarkFinished(Status::OK()); - } return Status::OK(); } @@ -3402,49 +3395,40 @@ class TpchNode : public ExecNode { // TODO(ARROW-16087) } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override { - if (generator_->Abort()) finished_.MarkFinished(); + Status StopProducingImpl() override { + generator_->Abort(); + return Status::OK(); } - Future<> finished() override { return finished_; } - private: - void OutputBatchCallback(ExecBatch batch) { - outputs_[0]->InputReceived(this, std::move(batch)); + Status OutputBatchCallback(ExecBatch batch) { + return output_->InputReceived(this, std::move(batch)); } - void FinishedCallback(int64_t total_num_batches) { - outputs_[0]->InputFinished(this, static_cast(total_num_batches)); + Status FinishedCallback(int64_t total_num_batches) { + ARROW_RETURN_NOT_OK( + output_->InputFinished(this, static_cast(total_num_batches))); finished_generating_.store(true); + return Status::OK(); } Status ScheduleTaskCallback(std::function func) { if (finished_generating_.load()) return Status::OK(); - num_running_++; - return plan_->query_context()->ScheduleTask( + plan_->query_context()->ScheduleTask( [this, func](size_t thread_index) { Status status = func(thread_index); if (!status.ok()) { - StopProducing(); - ErrorIfNotOk(status); - } - if (--num_running_ == 0) { - finished_.MarkFinished(Status::OK()); + ARROW_RETURN_NOT_OK(StopProducing()); } return status; }, "TpchNode::GenerateAndProcessBatch"); + return Status::OK(); } const char* name_; std::unique_ptr generator_; std::atomic finished_generating_{false}; - std::atomic num_running_{0}; }; class TpchGenImpl : public TpchGen { diff --git a/cpp/src/arrow/compute/exec/tpch_node_test.cc b/cpp/src/arrow/compute/exec/tpch_node_test.cc index c125d7c8161..d3bae1c768d 100644 --- a/cpp/src/arrow/compute/exec/tpch_node_test.cc +++ b/cpp/src/arrow/compute/exec/tpch_node_test.cc @@ -106,7 +106,7 @@ void VerifyStringAndNumber_Single(std::string_view row, std::string_view prefix, const char* num_str_end = row.data() + row.size(); int64_t num = 0; // Parse the number out; note that it can be padded with NUL chars at the end - for (; *num_str && num_str < num_str_end; num_str++) { + for (; num_str < num_str_end && *num_str; num_str++) { num *= 10; ASSERT_TRUE(std::isdigit(*num_str)) << row << ", prefix=" << prefix << ", i=" << i; num += *num_str - '0'; @@ -309,7 +309,7 @@ void VerifyOneOf(const Datum& d, int32_t byte_width, for (int64_t i = 0; i < length; i++) { const char* row = col + i * byte_width; int32_t row_len = 0; - while (row[row_len] && row_len < byte_width) row_len++; + while (row_len < byte_width && row[row_len]) row_len++; std::string_view view(row, row_len); ASSERT_TRUE(possibilities.find(view) != possibilities.end()) << view << " is not a valid string."; @@ -631,7 +631,7 @@ TEST(TpchNode, AllTables) { } ASSERT_OK(plan->Validate()); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_OK(plan->finished().status()); for (int i = 0; i < kNumTables; i++) { auto fut = CollectAsyncGenerator(gens[i]); diff --git a/cpp/src/arrow/compute/exec/union_node.cc b/cpp/src/arrow/compute/exec/union_node.cc index f9241a431e3..ae967517231 100644 --- a/cpp/src/arrow/compute/exec/union_node.cc +++ b/cpp/src/arrow/compute/exec/union_node.cc @@ -50,8 +50,7 @@ class UnionNode : public ExecNode, public TracedNode { public: UnionNode(ExecPlan* plan, std::vector inputs) : ExecNode(plan, inputs, GetInputLabels(inputs), - /*output_schema=*/inputs[0]->output_schema(), - /*num_outputs=*/1) { + /*output_schema=*/inputs[0]->output_schema()) { bool counter_completed = input_count_.SetTotal(static_cast(inputs.size())); ARROW_DCHECK(counter_completed == false); } @@ -76,37 +75,23 @@ class UnionNode : public ExecNode, public TracedNode { return plan->EmplaceNode(plan, std::move(inputs)); } - void InputReceived(ExecNode* input, ExecBatch batch) override { + Status InputReceived(ExecNode* input, ExecBatch batch) override { NoteInputReceived(batch); ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - if (finished_.is_finished()) { - return; - } - outputs_[0]->InputReceived(this, std::move(batch)); - if (batch_count_.Increment()) { - finished_.MarkFinished(); - } - } - - void ErrorReceived(ExecNode* input, Status error) override { - DCHECK_EQ(input, inputs_[0]); - outputs_[0]->ErrorReceived(this, std::move(error)); - - StopProducing(); + return output_->InputReceived(this, std::move(batch)); } - void InputFinished(ExecNode* input, int total_batches) override { + Status InputFinished(ExecNode* input, int total_batches) override { ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); total_batches_.fetch_add(total_batches); if (input_count_.Increment()) { - outputs_[0]->InputFinished(this, total_batches_.load()); - if (batch_count_.SetTotal(total_batches_.load())) { - finished_.MarkFinished(); - } + return output_->InputFinished(this, total_batches_.load()); } + + return Status::OK(); } Status StartProducing() override { @@ -126,29 +111,9 @@ class UnionNode : public ExecNode, public TracedNode { } } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - if (batch_count_.Cancel()) { - finished_.MarkFinished(); - } - for (auto&& input : inputs_) { - input->StopProducing(this); - } - } - - void StopProducing() override { - if (batch_count_.Cancel()) { - finished_.MarkFinished(); - } - for (auto&& input : inputs_) { - input->StopProducing(this); - } - } - - Future<> finished() override { return finished_; } + Status StopProducingImpl() override { return Status::OK(); } private: - AtomicCounter batch_count_; AtomicCounter input_count_; std::atomic total_batches_{0}; }; diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index 50d8cd49aba..3bc2a03ddb6 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -143,7 +143,7 @@ Result GroupByUsingExecPlan(const BatchesWithSchema& input, .AddToPlan(plan.get())); RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); auto collected_fut = CollectAsyncGenerator(sink_gen); @@ -162,7 +162,7 @@ Result GroupByUsingExecPlan(const BatchesWithSchema& input, start_and_collect.MoveResult()); ArrayVector out_arrays(aggregates.size() + key_names.size()); - const auto& output_schema = plan->sources()[0]->outputs()[0]->output_schema(); + const auto& output_schema = plan->nodes()[0]->output()->output_schema(); for (size_t i = 0; i < out_arrays.size(); ++i) { std::vector> arrays(output_batches.size()); for (size_t j = 0; j < output_batches.size(); ++j) { diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 4c74cf1f62a..29e9ab1f0c4 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -529,15 +529,9 @@ class TeeNode : public compute::MapNode { const char* kind_name() const override { return "TeeNode"; } - void Finish(Status finish_st) override { - if (!finish_st.ok()) { - MapNode::Finish(std::move(finish_st)); - return; - } - dataset_writer_->Finish(); - } + void Finish() override { dataset_writer_->Finish(); } - Result DoTee(const compute::ExecBatch& batch) { + Result ProcessBatch(compute::ExecBatch batch) override { ARROW_ASSIGN_OR_RAISE(std::shared_ptr record_batch, batch.ToRecordBatch(output_schema())); ARROW_RETURN_NOT_OK(WriteNextBatch(std::move(record_batch), batch.guarantee)); @@ -549,19 +543,12 @@ class TeeNode : public compute::MapNode { return WriteBatch(batch, guarantee, write_options_, [this](std::shared_ptr next_batch, const PartitionPathFormat& destination) { - util::tracing::Span span; dataset_writer_->WriteRecordBatch( next_batch, destination.directory, destination.filename); return Status::OK(); }); } - void InputReceived(compute::ExecNode* input, compute::ExecBatch batch) override { - DCHECK_EQ(input, inputs_[0]); - auto func = [this](compute::ExecBatch batch) { return DoTee(std::move(batch)); }; - this->SubmitTask(std::move(func), std::move(batch)); - } - void Pause() { inputs_[0]->PauseProducing(this, ++backpressure_counter_); } void Resume() { inputs_[0]->ResumeProducing(this, ++backpressure_counter_); } diff --git a/cpp/src/arrow/dataset/scan_node.cc b/cpp/src/arrow/dataset/scan_node.cc index d37d3316b3c..15b82fb0761 100644 --- a/cpp/src/arrow/dataset/scan_node.cc +++ b/cpp/src/arrow/dataset/scan_node.cc @@ -118,9 +118,7 @@ class ScanNode : public cp::ExecNode, public cp::TracedNode { public: ScanNode(cp::ExecPlan* plan, ScanV2Options options, std::shared_ptr output_schema) - : cp::ExecNode(plan, {}, {}, std::move(output_schema), - /*num_outputs=*/1), - options_(options) {} + : cp::ExecNode(plan, {}, {}, std::move(output_schema)), options_(options) {} static Result NormalizeAndValidate(const ScanV2Options& options, compute::ExecContext* ctx) { @@ -181,9 +179,8 @@ class ScanNode : public cp::ExecNode, public cp::TracedNode { [[noreturn]] static void NoInputs() { Unreachable("no inputs; this should never be called"); } - [[noreturn]] void InputReceived(cp::ExecNode*, cp::ExecBatch) override { NoInputs(); } - [[noreturn]] void ErrorReceived(cp::ExecNode*, Status) override { NoInputs(); } - [[noreturn]] void InputFinished(cp::ExecNode*, int) override { NoInputs(); } + [[noreturn]] Status InputReceived(cp::ExecNode*, cp::ExecBatch) override { NoInputs(); } + [[noreturn]] Status InputFinished(cp::ExecNode*, int) override { NoInputs(); } Status Init() override { return Status::OK(); } @@ -224,12 +221,12 @@ class ScanNode : public cp::ExecNode, public cp::TracedNode { compute::ExecBatch evolved_batch, scan_->fragment_evolution->EvolveBatch(batch, node_->options_.columns, scan_->scan_request.columns)); - return node_->plan_->query_context()->ScheduleTask( + node_->plan_->query_context()->ScheduleTask( [node = node_, evolved_batch = std::move(evolved_batch)] { - node->outputs_[0]->InputReceived(node, std::move(evolved_batch)); - return Status::OK(); + return node->output_->InputReceived(node, std::move(evolved_batch)); }, "ScanNode::ProcessMorsel"); + return Status::OK(); } int cost() const override { return cost_; } @@ -330,11 +327,8 @@ class ScanNode : public cp::ExecNode, public cp::TracedNode { std::shared_ptr fragment_tasks = util::MakeThrottledAsyncTaskGroup( plan_->query_context()->async_scheduler(), options_.fragment_readahead + 1, - /*queue=*/nullptr, [this]() { - outputs_[0]->InputFinished(this, num_batches_.load()); - finished_.MarkFinished(); - return Status::OK(); - }); + /*queue=*/nullptr, + [this]() { return output_->InputFinished(this, num_batches_.load()); }); fragment_tasks->AddAsyncGenerator>( std::move(frag_gen), [this, fragment_tasks = @@ -368,12 +362,7 @@ class ScanNode : public cp::ExecNode, public cp::TracedNode { // TODO(ARROW-17755) } - void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); - } - - void StopProducing() override {} + Status StopProducingImpl() override { return Status::OK(); } private: ScanV2Options options_; diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index f307787357d..4ca6594094a 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -448,7 +448,7 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( }) .AddToPlan(plan.get())); - RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); auto options = scan_options_; ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter)); @@ -466,13 +466,24 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( } }}; - return MakeMappedGenerator( + EnumeratedRecordBatchGenerator mapped_gen = MakeMappedGenerator( std::move(sink_gen), - [sink_gen, options, stop_producing, + [sink_gen, options, shared_fragments](const std::optional& batch) -> Future { return ToEnumeratedRecordBatch(batch, *options, *shared_fragments); }); + + return [mapped_gen = std::move(mapped_gen), plan = std::move(plan), + stop_producing = std::move(stop_producing)] { + auto next = mapped_gen(); + return next.Then([plan](const EnumeratedRecordBatch& value) { + if (IsIterationEnd(value)) { + return plan->finished().Then([value] { return value; }); + } + return Future::MakeFinished(value); + }); + }; } Result> AsyncScanner::TakeRows(const Array& indices) { diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index ea36db4b279..9ba8c0f68b1 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -2151,7 +2151,7 @@ struct TestPlan { Future> Run() { RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); auto collected_fut = CollectAsyncGenerator(sink_gen); @@ -2570,18 +2570,15 @@ TEST(ScanNode, MinimalEndToEnd) { // finally, pipe the project node into a sink node AsyncGenerator> sink_gen; - ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, - compute::MakeExecNode("ordered_sink", plan.get(), {project}, - compute::SinkNodeOptions{&sink_gen})); - - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); + ASSERT_OK(compute::MakeExecNode("ordered_sink", plan.get(), {project}, + compute::SinkNodeOptions{&sink_gen})); // translate sink_gen (async) to sink_reader (sync) std::shared_ptr sink_reader = compute::MakeGeneratorReader( schema({field("a * 2", int32())}), std::move(sink_gen), exec_context.memory_pool()); // start the ExecPlan - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); // collect sink_reader into a Table ASSERT_OK_AND_ASSIGN(auto collected, Table::FromRecordBatchReader(sink_reader.get())); @@ -2673,11 +2670,8 @@ TEST(ScanNode, MinimalScalarAggEndToEnd) { // finally, pipe the aggregate node into a sink node AsyncGenerator> sink_gen; - ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, - compute::MakeExecNode("sink", plan.get(), {aggregate}, - compute::SinkNodeOptions{&sink_gen})); - - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); + ASSERT_OK(compute::MakeExecNode("sink", plan.get(), {aggregate}, + compute::SinkNodeOptions{&sink_gen})); // translate sink_gen (async) to sink_reader (sync) std::shared_ptr sink_reader = @@ -2685,7 +2679,7 @@ TEST(ScanNode, MinimalScalarAggEndToEnd) { std::move(sink_gen), exec_context.memory_pool()); // start the ExecPlan - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); // collect sink_reader into a Table ASSERT_OK_AND_ASSIGN(auto collected, Table::FromRecordBatchReader(sink_reader.get())); @@ -2766,11 +2760,8 @@ TEST(ScanNode, MinimalGroupedAggEndToEnd) { // finally, pipe the aggregate node into a sink node AsyncGenerator> sink_gen; - ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, - compute::MakeExecNode("sink", plan.get(), {aggregate}, - compute::SinkNodeOptions{&sink_gen})); - - ASSERT_THAT(plan->sinks(), ElementsAre(sink)); + ASSERT_OK(compute::MakeExecNode("sink", plan.get(), {aggregate}, + compute::SinkNodeOptions{&sink_gen})); // translate sink_gen (async) to sink_reader (sync) std::shared_ptr sink_reader = compute::MakeGeneratorReader( @@ -2778,7 +2769,7 @@ TEST(ScanNode, MinimalGroupedAggEndToEnd) { exec_context.memory_pool()); // start the ExecPlan - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); // collect sink_reader into a Table ASSERT_OK_AND_ASSIGN(auto collected, Table::FromRecordBatchReader(sink_reader.get())); diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc index 7e1902eaae4..eca033fb0cc 100644 --- a/cpp/src/arrow/engine/substrait/function_test.cc +++ b/cpp/src/arrow/engine/substrait/function_test.cc @@ -140,7 +140,7 @@ void CheckValidTestCases(const std::vector& valid_cases) { std::shared_ptr output_table; ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, PlanFromTestCase(test_case, &output_table)); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); // Could also modify the Substrait plan with an emit to drop the leading columns @@ -161,12 +161,7 @@ void CheckErrorTestCases(const std::vector& error_cases) { std::shared_ptr
output_table; ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, PlanFromTestCase(test_case, &output_table)); - Status start_st = plan->StartProducing(); - // The plan can fail in start producing or when running the plan - if (!start_st.ok()) { - ASSERT_TRUE(start_st.IsInvalid()); - return; - } + plan->StartProducing(); ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished()); } } @@ -593,7 +588,7 @@ void CheckWholeAggregateCase(const AggregateTestCase& test_case) { std::shared_ptr plan = PlanFromAggregateCase(test_case, &output_table, /*with_keys=*/false); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); ASSERT_OK_AND_ASSIGN(output_table, @@ -609,7 +604,7 @@ void CheckGroupedAggregateCase(const AggregateTestCase& test_case) { std::shared_ptr plan = PlanFromAggregateCase(test_case, &output_table, /*with_keys=*/true); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); // The aggregate node's output is unpredictable so we sort by the key column diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 2916782fe04..47c371e9e8d 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1120,7 +1120,7 @@ TEST(Substrait, DeserializeWithConsumerFactory) { auto& prev_node = sink_node->inputs()[0]; ASSERT_STREQ(prev_node->kind_name(), "SourceNode"); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); } @@ -1129,14 +1129,14 @@ TEST(Substrait, DeserializeSinglePlanWithConsumerFactory) { ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json)); ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, DeserializePlan(*buf, compute::NullSinkNodeConsumer::Make())); - ASSERT_EQ(1, plan->sinks().size()); - compute::ExecNode* sink_node = plan->sinks()[0]; + ASSERT_EQ(2, plan->nodes().size()); + compute::ExecNode* sink_node = plan->nodes()[1]; ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode"); ASSERT_EQ(sink_node->num_inputs(), 1); auto& prev_node = sink_node->inputs()[0]; ASSERT_STREQ(prev_node->kind_name(), "SourceNode"); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); } @@ -1175,7 +1175,7 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) { auto& prev_node = sink_node->inputs()[0]; ASSERT_STREQ(prev_node->kind_name(), "SourceNode"); - ASSERT_OK(plan->StartProducing()); + plan->StartProducing(); ASSERT_FINISHES_OK(plan->finished()); } diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc index c66510c7c25..43b69d669fc 100644 --- a/cpp/src/arrow/flight/sql/example/acero_server.cc +++ b/cpp/src/arrow/flight/sql/example/acero_server.cc @@ -200,7 +200,7 @@ class AceroFlightSqlServer : public FlightSqlServerBase { ARROW_LOG(INFO) << "DoGetStatement: executing plan " << plan->ToString(); - ARROW_RETURN_NOT_OK(plan->StartProducing()); + plan->StartProducing(); auto reader = std::make_shared(std::move(plan), std::move(consumer)); @@ -268,11 +268,13 @@ class AceroFlightSqlServer : public FlightSqlServerBase { ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, engine::DeserializePlan(*plan_buf, consumer)); std::shared_ptr output_schema; - for (compute::ExecNode* sink : plan->sinks()) { - // Force SinkNodeConsumer::Init to be called - ARROW_RETURN_NOT_OK(sink->StartProducing()); - output_schema = consumer->schema(); - break; + for (compute::ExecNode* possible_sink : plan->nodes()) { + if (possible_sink->is_sink()) { + // Force SinkNodeConsumer::Init to be called + ARROW_RETURN_NOT_OK(possible_sink->StartProducing()); + output_schema = consumer->schema(); + break; + } } if (!output_schema) { return Status::Invalid("Could not infer output schema"); diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index df6a883afe9..2d87971422c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2627,7 +2627,7 @@ cdef extern from "arrow/compute/exec/exec_plan.h" namespace "arrow::compute" nog @staticmethod CResult[shared_ptr[CExecPlan]] Make(CExecContext* exec_context) - CStatus StartProducing() + void StartProducing() CStatus Validate() CStatus StopProducing() diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 64ea6f5b5ef..cb1e90f56f3 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -100,7 +100,7 @@ class ExecPlanReader : public arrow::RecordBatchReader { // If this is the first batch getting pulled, tell the exec plan to // start producing if (plan_status_ == PLAN_NOT_STARTED) { - ARROW_RETURN_NOT_OK(StartProducing()); + StartProducing(); } // If we've closed the reader, keep sending nullptr @@ -134,7 +134,8 @@ class ExecPlanReader : public arrow::RecordBatchReader { *batch_out = batch_result.ValueUnsafe(); } else { batch_out->reset(); - StopProducing(); + plan_status_ = PLAN_FINISHED; + return plan_->finished().status(); } return arrow::Status::OK(); @@ -156,10 +157,9 @@ class ExecPlanReader : public arrow::RecordBatchReader { ExecPlanReaderStatus plan_status_; arrow::StopToken stop_token_; - arrow::Status StartProducing() { - ARROW_RETURN_NOT_OK(plan_->StartProducing()); + void StartProducing() { + plan_->StartProducing(); plan_status_ = PLAN_RUNNING; - return arrow::Status::OK(); } void StopProducing() { @@ -349,9 +349,8 @@ void ExecPlan_Write( StopIfNotOk(plan->Validate()); arrow::Status result = RunWithCapturedRIfPossibleVoid([&]() { - RETURN_NOT_OK(plan->StartProducing()); - RETURN_NOT_OK(plan->finished().status()); - return arrow::Status::OK(); + plan->StartProducing(); + return plan->finished().status(); }); StopIfNotOk(result); @@ -537,7 +536,7 @@ std::shared_ptr ExecPlan_run_substrait( } StopIfNotOk(plan->Validate()); - StopIfNotOk(plan->StartProducing()); + plan->StartProducing(); StopIfNotOk(plan->finished().status()); std::vector> all_batches;