diff --git a/cpp/examples/arrow/compute_register_example.cc b/cpp/examples/arrow/compute_register_example.cc index 6e5ff015387..0d09e2b6d61 100644 --- a/cpp/examples/arrow/compute_register_example.cc +++ b/cpp/examples/arrow/compute_register_example.cc @@ -95,8 +95,8 @@ class ExampleNode : public cp::ExecNode { void ResumeProducing(ExecNode* output) override {} void PauseProducing(ExecNode* output) override {} - void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); } - void StopProducing() override { inputs_[0]->StopProducing(); } + void StopProducing(ExecNode* output) override {} + void StopProducing() override {} void InputReceived(ExecNode* input, cp::ExecBatch batch) override {} void ErrorReceived(ExecNode* input, arrow::Status error) override {} diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 78280ba3fc0..78543ad65b7 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -1031,6 +1031,8 @@ ExecContext::ExecContext(MemoryPool* pool, ::arrow::internal::Executor* executor FunctionRegistry* func_registry) : pool_(pool), executor_(executor) { this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry; + DCHECK_NE(nullptr, executor_); + this->use_threads_ = executor_->GetCapacity() > 1; } CpuInfo* ExecContext::cpu_info() const { return CpuInfo::GetInstance(); } diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index faebddb7334..68d55dc40b2 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -34,6 +34,7 @@ #include "arrow/result.h" #include "arrow/type_fwd.h" #include "arrow/util/macros.h" +#include "arrow/util/thread_pool.h" #include "arrow/util/type_fwd.h" #include "arrow/util/visibility.h" @@ -60,9 +61,10 @@ static constexpr int64_t kDefaultExecChunksize = UINT16_MAX; class ARROW_EXPORT ExecContext { public: // If no function registry passed, the default is used. - explicit ExecContext(MemoryPool* pool = default_memory_pool(), - ::arrow::internal::Executor* executor = NULLPTR, - FunctionRegistry* func_registry = NULLPTR); + explicit ExecContext( + MemoryPool* pool = default_memory_pool(), + ::arrow::internal::Executor* executor = ::arrow::internal::GetCpuThreadPool(), + FunctionRegistry* func_registry = NULLPTR); /// \brief The MemoryPool used for allocations, default is /// default_memory_pool(). 
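The exec.h/exec.cc hunks above change the ExecContext contract: the executor may no longer be null, and use_threads() is now derived from the executor's capacity. A minimal sketch, assuming only the APIs shown in this diff (the SerialExecContext/MakeSerialExecContext names are hypothetical), of how a caller that previously passed a null executor for serial execution can adapt, mirroring the SimpleExecContext helper added to the tests below:

```cpp
#include <memory>

#include "arrow/compute/exec.h"
#include "arrow/util/thread_pool.h"

// Keep the pool alongside the context so the executor outlives it.
struct SerialExecContext {
  std::shared_ptr<arrow::internal::ThreadPool> pool;
  std::unique_ptr<arrow::compute::ExecContext> ctx;
};

arrow::Result<SerialExecContext> MakeSerialExecContext() {
  SerialExecContext out;
  // A capacity-1 pool makes ExecContext::use_threads() false under the new
  // GetCapacity() > 1 rule while still satisfying the non-null DCHECK.
  ARROW_ASSIGN_OR_RAISE(out.pool, arrow::internal::ThreadPool::Make(1));
  out.ctx = std::make_unique<arrow::compute::ExecContext>(
      arrow::default_memory_pool(), out.pool.get());
  return out;
}
```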
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index b4982ef111f..3f7d01b60fb 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -246,7 +246,6 @@ class ScalarAggregateNode : public ExecNode { if (input_counter_.Cancel()) { finished_.MarkFinished(); } - inputs_[0]->StopProducing(this); } Future<> finished() override { return finished_; } @@ -531,7 +530,7 @@ class GroupByNode : public ExecNode { auto executor = ctx_->executor(); for (int i = 0; i < num_output_batches; ++i) { - if (executor) { + if (ctx_->use_threads()) { // bail if StopProducing was called if (finished_.is_finished()) break; @@ -610,7 +609,6 @@ class GroupByNode : public ExecNode { if (output_counter_.Cancel()) { finished_.MarkFinished(); } - inputs_[0]->StopProducing(this); } void StopProducing() override { StopProducing(outputs_[0]); } diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index c383c6092af..55ab4d1cb4d 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -360,16 +360,11 @@ bool ExecNode::ErrorIfNotOk(Status status) { } MapNode::MapNode(ExecPlan* plan, std::vector inputs, - std::shared_ptr output_schema, bool async_mode) + std::shared_ptr output_schema, bool use_threads) : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"}, std::move(output_schema), - /*num_outputs=*/1) { - if (async_mode) { - executor_ = plan_->exec_context()->executor(); - } else { - executor_ = nullptr; - } -} + /*num_outputs=*/1), + use_threads_(use_threads) {} void MapNode::ErrorReceived(ExecNode* input, Status error) { DCHECK_EQ(input, inputs_[0]); @@ -406,13 +401,14 @@ void MapNode::StopProducing(ExecNode* output) { void MapNode::StopProducing() { EVENT(span_, "StopProducing"); - if (executor_) { + if (use_threads_) { + // If we are using tasks we may have a bunch of queued tasks that we should + // cancel this->stop_source_.RequestStop(); } if (input_counter_.Cancel()) { this->Finish(); } - inputs_[0]->StopProducing(this); } Future<> MapNode::finished() { return finished_; } @@ -436,15 +432,16 @@ void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, return Status::OK(); }; - if (executor_) { + if (use_threads_) { status = task_group_.AddTask([this, task]() -> Result> { - return this->executor_->Submit(this->stop_source_.token(), [this, task]() { - auto status = task(); - if (this->input_counter_.Increment()) { - this->Finish(status); - } - return status; - }); + return this->plan()->exec_context()->executor()->Submit( + this->stop_source_.token(), [this, task]() { + auto status = task(); + if (this->input_counter_.Increment()) { + this->Finish(status); + } + return status; + }); }); } else { status = task(); @@ -458,13 +455,12 @@ void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, if (input_counter_.Cancel()) { this->Finish(status); } - inputs_[0]->StopProducing(this); return; } } void MapNode::Finish(Status finish_st /*= Status::OK()*/) { - if (executor_) { + if (use_threads_) { task_group_.End().AddCallback([this, finish_st](const Status& st) { Status final_status = finish_st & st; this->finished_.MarkFinished(final_status); diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 85bfd5f756d..7fbc50c594f 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -273,7 +273,7 @@ class ARROW_EXPORT ExecNode { class 
MapNode : public ExecNode { public: MapNode(ExecPlan* plan, std::vector inputs, - std::shared_ptr output_schema, bool async_mode); + std::shared_ptr output_schema, bool use_threads); void ErrorReceived(ExecNode* input, Status error) override; @@ -303,7 +303,10 @@ class MapNode : public ExecNode { // The task group for the corresponding batches util::AsyncTaskGroup task_group_; - ::arrow::internal::Executor* executor_; + // If true then tasks will be spawned for each item + // + // If false the item will be processed immediately and synchronously + bool use_threads_; // Variable used to cancel remaining tasks in the executor StopSource stop_source_; diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index a0c75af3fa9..edeffb72d26 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -64,13 +64,24 @@ BatchesWithSchema GenerateBatchesFromString( return out_batches; } +std::unique_ptr SimpleExecContext( + bool parallel, std::shared_ptr<::arrow::internal::ThreadPool>* owned_thread_pool) { + if (parallel) { + return arrow::internal::make_unique(); + } else { + EXPECT_OK_AND_ASSIGN(*owned_thread_pool, ::arrow::internal::ThreadPool::Make(1)); + return arrow::internal::make_unique(default_memory_pool(), + owned_thread_pool->get()); + } +} + void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches, const BatchesWithSchema& r_batches, const std::vector& left_keys, const std::vector& right_keys, const BatchesWithSchema& exp_batches, bool parallel = false) { - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool; + auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); @@ -904,8 +915,8 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel, const std::vector>& l, const std::vector>& r, int num_batches_l, int num_batches_r, std::shared_ptr* output) { - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool; + auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); @@ -1013,8 +1024,8 @@ TEST(HashJoin, Random) { #endif for (int test_id = 0; test_id < num_tests; ++test_id) { bool parallel = (rng.from_range(0, 1) == 1); - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool; + auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool); // Constraints RandomDataTypeConstraints type_constraints; @@ -1297,8 +1308,8 @@ void TestHashJoinDictionaryHelper( } } - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool; + auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); ASSERT_OK_AND_ASSIGN( ExecNode * l_source, @@ -1699,8 +1710,8 @@ TEST(HashJoin, DictNegative) { ExecBatch::Make({i == 2 ? datumSecondB : datumSecondA, i == 3 ? 
datumSecondB : datumSecondA})); - auto exec_ctx = - arrow::internal::make_unique(default_memory_pool(), nullptr); + std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool; + auto exec_ctx = SimpleExecContext(/*parallel=*/false, &owned_thread_pool); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); ASSERT_OK_AND_ASSIGN( ExecNode * l_source, @@ -1728,6 +1739,7 @@ TEST(HashJoin, DictNegative) { EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT( NotImplemented, ::testing::HasSubstr("Unifying differing dictionaries"), StartAndCollect(plan.get(), sink_gen)); + ASSERT_FINISHES_OK(plan->finished()); } } @@ -1786,8 +1798,8 @@ TEST(HashJoin, ResidualFilter) { input_right.schema = schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())}); - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool; + auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); AsyncGenerator> sink_gen; @@ -1863,9 +1875,8 @@ TEST(HashJoin, TrivialResidualFilter) { ])")}; input_right.schema = schema({field("r1", int32()), field("r_str", utf8())}); - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), - parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool; + auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); AsyncGenerator> sink_gen; diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index e176c701b65..773e42f2a7a 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -478,6 +478,11 @@ TEST(ExecPlanExecution, SourceSinkError) { ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), Finishes(Raises(StatusCode::Invalid, HasSubstr("Artificial")))); + // Note: the plan may or may not be finished at this point. When an error + // hits the sink node it starts to mark itself finished but before that it emits + // the error to the producer which will cause the above wait to finish (possibly + // before the plan has marked itself finished). So we wait for the plan to finish. + ASSERT_FINISHES_OK(plan->finished()); } TEST(ExecPlanExecution, SourceConsumingSink) { @@ -527,7 +532,7 @@ TEST(ExecPlanExecution, SourceConsumingSink) { } } -TEST(ExecPlanExecution, SourceTableConsumingSink) { +TEST(ExecPlanExecution, SourceTableSink) { for (bool slow : {false, true}) { SCOPED_TRACE(slow ? 
"slowed" : "unslowed"); @@ -549,11 +554,12 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) { ASSERT_OK(plan->StartProducing()); // Source should finish fairly quickly ASSERT_FINISHES_OK(source->finished()); - SleepABit(); - ASSERT_OK_AND_ASSIGN(auto actual, + ASSERT_OK_AND_ASSIGN(auto expected, TableFromExecBatches(basic_data.schema, basic_data.batches)); ASSERT_EQ(5, out->num_rows()); - AssertTablesEqual(*actual, *out); + ASSERT_OK_AND_ASSIGN(auto expected_sorted, SortTableOnAllFields(expected)); + ASSERT_OK_AND_ASSIGN(auto out_sorted, SortTableOnAllFields(out)); + AssertTablesEqual(*expected_sorted, *out_sorted); ASSERT_FINISHES_OK(plan->finished()); } } @@ -856,11 +862,11 @@ TEST(ExecPlanExecution, SourceGroupedSum) { }) .AddToPlan(plan.get())); - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0), Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON( {int64(), utf8()}, - parallel ? R"([[800, "alfa"], [1000, "beta"], [400, "gama"]])" - : R"([[8, "alfa"], [10, "beta"], [4, "gama"]])")})))); + parallel ? R"([[400, "gama"], [800, "alfa"], [1000, "beta"]])" + : R"([[4, "gama"], [8, "alfa"], [10, "beta"]])")})))); } } @@ -900,9 +906,9 @@ TEST(ExecPlanExecution, NestedSourceProjectGroupedSum) { auto input = MakeNestedBatches(); auto expected = ExecBatchFromJSON({int64(), boolean()}, R"([ - [null, true], + [5, null], [17, false], - [5, null] + [null, true] ])"); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); @@ -926,7 +932,7 @@ TEST(ExecPlanExecution, NestedSourceProjectGroupedSum) { }) .AddToPlan(plan.get())); - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0), Finishes(ResultWith(UnorderedElementsAreArray({expected})))); } } @@ -962,10 +968,10 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumFilter) { }) .AddToPlan(plan.get())); - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0), Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON( - {int64(), utf8()}, parallel ? R"([[3600, "alfa"], [2000, "beta"]])" - : R"([[36, "alfa"], [20, "beta"]])")})))); + {int64(), utf8()}, parallel ? R"([[2000, "beta"], [3600, "alfa"]])" + : R"([[20, "beta"], [36, "alfa"]])")})))); } } @@ -1123,7 +1129,7 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) { .AddToPlan(plan.get())); } - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 1), Finishes(ResultWith(UnorderedElementsAreArray({ ExecBatchFromJSON({int64(), utf8()}, R"([[500, "alfa"], [200, "beta"], [200, "gama"]])"), @@ -1211,7 +1217,7 @@ TEST(ExecPlanExecution, ScalarSourceGroupedSum) { }) .AddToPlan(plan.get())); - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0), Finishes(ResultWith(UnorderedElementsAreArray({ ExecBatchFromJSON({int64(), boolean()}, R"([[6, true], [18, false]])"), })))); @@ -1223,8 +1229,8 @@ TEST(ExecPlanExecution, SelfInnerHashJoinSink) { auto input = MakeGroupableBatches(); - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), parallel ? 
arrow::internal::GetCpuThreadPool() : nullptr); + auto exec_ctx = arrow::internal::make_unique(default_memory_pool()); + exec_ctx->set_use_threads(parallel); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); AsyncGenerator> sink_gen; @@ -1280,8 +1286,8 @@ TEST(ExecPlanExecution, SelfOuterHashJoinSink) { auto input = MakeGroupableBatches(); - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + auto exec_ctx = arrow::internal::make_unique(default_memory_pool()); + exec_ctx->set_use_threads(parallel); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); AsyncGenerator> sink_gen; diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 13564c736b5..d2525325854 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -53,7 +53,10 @@ class SinkNode : public ExecNode { util::BackpressureOptions backpressure) : ExecNode(plan, std::move(inputs), {"collected"}, {}, /*num_outputs=*/0), - producer_(MakeProducer(generator, std::move(backpressure))) {} + producer_(MakeProducer(generator, std::move(backpressure))) { + DCHECK_EQ(1, inputs_.size()); + output_schema_ = inputs_[0]->output_schema(); + } static Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { @@ -96,9 +99,9 @@ class SinkNode : public ExecNode { void StopProducing() override { EVENT(span_, "StopProducing"); - - Finish(); - inputs_[0]->StopProducing(this); + if (input_counter_.Cancel()) { + Finish(); + } } Future<> finished() override { return finished_; } @@ -128,7 +131,6 @@ class SinkNode : public ExecNode { if (input_counter_.Cancel()) { Finish(); } - inputs_[0]->StopProducing(this); } void InputFinished(ExecNode* input, int total_batches) override { @@ -140,9 +142,8 @@ class SinkNode : public ExecNode { protected: virtual void Finish() { - if (producer_.Close()) { - finished_.MarkFinished(); - } + producer_.Close(); + finished_.MarkFinished(); } AtomicCounter input_counter_; @@ -193,8 +194,9 @@ class ConsumingSinkNode : public ExecNode { void StopProducing() override { EVENT(span_, "StopProducing"); - Finish(Status::Invalid("ExecPlan was stopped early")); - inputs_[0]->StopProducing(this); + if (input_counter_.Cancel()) { + Finish(Status::OK()); + } } Future<> finished() override { return finished_; } @@ -218,7 +220,6 @@ class ConsumingSinkNode : public ExecNode { if (input_counter_.Cancel()) { Finish(std::move(consumption_status)); } - inputs_[0]->StopProducing(this); return; } @@ -234,8 +235,6 @@ class ConsumingSinkNode : public ExecNode { if (input_counter_.Cancel()) { Finish(std::move(error)); } - - inputs_[0]->StopProducing(this); } void InputFinished(ExecNode* input, int total_batches) override { diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 6b9a2ee9132..414656da3d7 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -88,14 +88,12 @@ struct SourceNode : ExecNode { CallbackOptions options; auto executor = plan()->exec_context()->executor(); - if (executor) { - // These options will transfer execution to the desired Executor if necessary. - // This can happen for in-memory scans where batches didn't require - // any CPU work to decode. Otherwise, parsing etc should have already - // been placed us on the desired Executor and no queues will be pushed to. 
- options.executor = executor; - options.should_schedule = ShouldSchedule::IfDifferentExecutor; - } + // These options will transfer execution to the desired Executor if necessary. + // This can happen for in-memory scans where batches didn't require + // any CPU work to decode. Otherwise, parsing etc. should already have + // placed us on the desired Executor and no queues will be pushed to. + options.executor = executor; + options.should_schedule = ShouldSchedule::IfDifferentExecutor; finished_ = Loop([this, executor, options] { std::unique_lock<std::mutex> lock(mutex_); diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 7733ea46084..c9839a6c567 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -102,9 +102,6 @@ struct DummyNode : ExecNode { void StopProducing() override { if (started_) { - for (const auto& input : inputs_) { - input->StopProducing(this); - } if (stop_producing_) { stop_producing_(this); } @@ -129,6 +126,41 @@ struct DummyNode : ExecNode { bool started_ = false; }; +std::shared_ptr<Schema> AnonymousSchema(const ExecBatch& batch) { + FieldVector fields; + for (int i = 0; i < batch.num_values(); i++) { + fields.push_back(field("f" + std::to_string(i), batch[i].type())); + } + return schema(std::move(fields)); +} + +Result<std::shared_ptr<RecordBatch>> ExecBatchToAnonymousRecordBatch( + const ExecBatch& batch) { + return batch.ToRecordBatch(AnonymousSchema(batch)); +} + +Result<std::shared_ptr<Table>> ExecBatchesToAnonymousTable( + const std::vector<ExecBatch>& batches) { + DCHECK_GT(batches.size(), 0); + std::shared_ptr<Schema> schema = AnonymousSchema(batches[0]); + return TableFromExecBatches(schema, batches); +} + +std::vector<ExecBatch> TableToExecBatches(const Table& table) { + std::vector<ExecBatch> exec_batches; + std::shared_ptr<ChunkedArray> sample_array = table.column(0); + for (int i = 0; i < sample_array->num_chunks(); i++) { + std::vector<Datum> datums; + for (int j = 0; j < table.num_columns(); j++) { + DCHECK_EQ(sample_array->chunk(i)->length(), table.column(j)->chunk(i)->length()); + datums.push_back(table.column(j)->chunk(i)); + } + exec_batches.push_back( + ExecBatch(std::move(datums), sample_array->chunk(i)->length())); + } + return exec_batches; +} + } // namespace ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector<ExecNode*> inputs, @@ -181,6 +213,14 @@ Future<std::vector<ExecBatch>> StartAndCollect( }); } +Future<std::vector<ExecBatch>> StartAndCollectSortedByField( + ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen, int field_index) { + return StartAndCollect(plan, std::move(gen)) + .Then([field_index](const std::vector<ExecBatch>& batches) { + return SortBatchesByField(batches, field_index); + }); +} + BatchesWithSchema MakeBasicBatches() { BatchesWithSchema out; out.batches = { @@ -231,6 +271,24 @@ Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab) {
return tab_sorted.table(); } +Result<ExecBatch> SortBatchByField(const ExecBatch& batch, int field_index) { + ARROW_ASSIGN_OR_RAISE(auto sort_indices, + SortIndices(batch[field_index], SortOptions{})); + ARROW_ASSIGN_OR_RAISE(auto record_batch, ExecBatchToAnonymousRecordBatch(batch)); + ARROW_ASSIGN_OR_RAISE(auto sorted_data, Take(record_batch, sort_indices)); + return ExecBatch(*sorted_data.record_batch()); +} + +Result<std::vector<ExecBatch>> SortBatchesByField(const std::vector<ExecBatch>& batches, + int field_index) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> table, + ExecBatchesToAnonymousTable(batches)); + ARROW_ASSIGN_OR_RAISE(auto sort_indices, + SortIndices(table->column(field_index), SortOptions{})); + ARROW_ASSIGN_OR_RAISE(auto sorted_data, Take(table, sort_indices)); + return TableToExecBatches(*sorted_data.table()); +} + void AssertTablesEqual(const std::shared_ptr<Table>& exp, const std::shared_ptr<Table>& act) { ASSERT_EQ(exp->num_columns(), act->num_columns()); diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 4a46ca95839..a3676e49142 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -84,6 +84,10 @@ ARROW_TESTING_EXPORT Future<std::vector<ExecBatch>> StartAndCollect( ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen); +ARROW_TESTING_EXPORT +Future<std::vector<ExecBatch>> StartAndCollectSortedByField( + ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen, int field_index); + ARROW_TESTING_EXPORT BatchesWithSchema MakeBasicBatches(); @@ -97,6 +101,13 @@ BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema, ARROW_TESTING_EXPORT Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab); +ARROW_TESTING_EXPORT +Result<ExecBatch> SortBatchByField(const ExecBatch& batch, int field_index); + +ARROW_TESTING_EXPORT +Result<std::vector<ExecBatch>> SortBatchesByField(const std::vector<ExecBatch>& batches, + int field_index); + ARROW_TESTING_EXPORT void AssertTablesEqual(const std::shared_ptr<Table>& exp, const std::shared_ptr<Table>
& act); diff --git a/cpp/src/arrow/compute/exec/union_node.cc b/cpp/src/arrow/compute/exec/union_node.cc index 9232516cc6d..827421610d8 100644 --- a/cpp/src/arrow/compute/exec/union_node.cc +++ b/cpp/src/arrow/compute/exec/union_node.cc @@ -131,18 +131,12 @@ class UnionNode : public ExecNode { if (batch_count_.Cancel()) { finished_.MarkFinished(); } - for (auto&& input : inputs_) { - input->StopProducing(this); - } } void StopProducing() override { if (batch_count_.Cancel()) { finished_.MarkFinished(); } - for (auto&& input : inputs_) { - input->StopProducing(this); - } } Future<> finished() override { return finished_; } diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index 198cb84ff5e..3a63b8f6649 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -44,6 +44,7 @@ #include "arrow/util/cpu_info.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" namespace arrow { @@ -70,10 +71,12 @@ TEST(ExecContext, BasicWorkings) { // Now, let's customize all the things LoggingMemoryPool my_pool(default_memory_pool()); std::unique_ptr custom_reg = FunctionRegistry::Make(); - ExecContext ctx(&my_pool, /*executor=*/nullptr, custom_reg.get()); + ASSERT_OK_AND_ASSIGN(auto executor, ::arrow::internal::ThreadPool::Make(1)); + ExecContext ctx(&my_pool, executor.get(), custom_reg.get()); ASSERT_EQ(custom_reg.get(), ctx.func_registry()); ASSERT_EQ(&my_pool, ctx.memory_pool()); + ASSERT_EQ(executor.get(), ctx.executor()); ctx.set_exec_chunksize(1 << 20); ASSERT_EQ(1 << 20, ctx.exec_chunksize()); diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index e292cf4a9bc..eefcad87486 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -81,7 +81,8 @@ InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches, } Result InMemoryFragment::ScanBatchesAsync( - const std::shared_ptr& options) { + const std::shared_ptr& options, + ::arrow::internal::Executor* cpu_executor) { struct State { State(std::shared_ptr fragment, int64_t batch_size) : fragment(std::move(fragment)), diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 9f4fee52154..d12ee85c9d5 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -31,6 +31,7 @@ #include "arrow/util/macros.h" #include "arrow/util/mutex.h" #include "arrow/util/optional.h" +#include "arrow/util/thread_pool.h" namespace arrow { namespace dataset { @@ -55,9 +56,10 @@ class ARROW_DS_EXPORT Fragment : public std::enable_shared_from_this { /// The schema is cached after being read once, or may be specified at construction. Result> ReadPhysicalSchema(); - /// An asynchronous version of Scan + /// \brief Scans the fragment and returns the data as batches virtual Result ScanBatchesAsync( - const std::shared_ptr& options) = 0; + const std::shared_ptr& options, + ::arrow::internal::Executor* cpu_executor) = 0; /// \brief Count the number of rows in this fragment matching the filter using metadata /// only. That is, this method may perform I/O, but will not load data. 
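Since Fragment::ScanBatchesAsync now receives the plan's CPU executor explicitly, a custom fragment is expected to transfer work onto it instead of hard-coding the global CPU pool. A hedged sketch under that assumption (the fragment type and its iterator member are hypothetical; it mirrors the OneShotFragment change in scanner.cc later in this diff):

```cpp
#include "arrow/dataset/dataset.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/thread_pool.h"

namespace arrow {
namespace dataset {

// Hypothetical one-shot fragment wrapping a RecordBatchIterator.
class MyIteratorFragment : public Fragment {
 public:
  explicit MyIteratorFragment(RecordBatchIterator it) : it_(std::move(it)) {}

  Result<RecordBatchGenerator> ScanBatchesAsync(
      const std::shared_ptr<ScanOptions>& options,
      ::arrow::internal::Executor* cpu_executor) override {
    // Read on the I/O executor...
    ARROW_ASSIGN_OR_RAISE(
        auto background_gen,
        MakeBackgroundGenerator(std::move(it_), options->io_context.executor()));
    // ...then transfer to the caller-supplied CPU executor rather than
    // hard-coding GetCpuThreadPool(), so serial plans stay on one thread.
    return MakeTransferredGenerator(std::move(background_gen), cpu_executor);
  }

  std::string type_name() const override { return "my-iterator"; }

 protected:
  // Remaining Fragment overrides elided for brevity.
  Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
    return Status::NotImplemented("physical schema");
  }

 private:
  RecordBatchIterator it_;
};

}  // namespace dataset
}  // namespace arrow
```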
@@ -119,7 +121,8 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment { compute::Expression = compute::literal(true)); Result ScanBatchesAsync( - const std::shared_ptr& options) override; + const std::shared_ptr& options, + ::arrow::internal::Executor* cpu_executor) override; Future> CountRows( compute::Expression predicate, const std::shared_ptr& options) override; diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index f4551c27590..cedaa7008f3 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -118,9 +118,10 @@ Result> FileFragment::ReadPhysicalSchemaImpl() { } Result FileFragment::ScanBatchesAsync( - const std::shared_ptr& options) { + const std::shared_ptr& options, + ::arrow::internal::Executor* cpu_executor) { auto self = std::dynamic_pointer_cast(shared_from_this()); - return format_->ScanBatchesAsync(options, self); + return format_->ScanBatchesAsync(options, self, cpu_executor); } Future> FileFragment::CountRows( @@ -332,13 +333,10 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer { util::SerializedAsyncTaskGroup task_group_; }; -} // namespace - -Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options, - std::shared_ptr scanner) { +Future<> DoWriteAsync(const FileSystemDatasetWriteOptions& write_options, + std::shared_ptr scanner, + ::arrow::internal::Executor* cpu_executor) { const io::IOContext& io_context = scanner->options()->io_context; - auto cpu_executor = - scanner->options()->use_threads ? ::arrow::internal::GetCpuThreadPool() : nullptr; std::shared_ptr exec_context = std::make_shared(io_context.pool(), cpu_executor); @@ -366,7 +364,26 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio .AddToPlan(plan.get())); RETURN_NOT_OK(plan->StartProducing()); - return plan->finished().status(); + return plan->finished().Then([plan, exec_context]() { + // Keep plan and exec_context alive until the plan is finished + }); +} + +} // namespace + +Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options, + std::shared_ptr scanner) { + if (scanner->options()->use_threads) { + return DoWriteAsync(std::move(write_options), std::move(scanner), + ::arrow::internal::GetCpuThreadPool()) + .status(); + } else { + return ::arrow::internal::SerialExecutor::RunInSerialExecutor< + ::arrow::internal::Empty>( + [write_options, scanner](::arrow::internal::Executor* executor) { + return DoWriteAsync(std::move(write_options), std::move(scanner), executor); + }); + } } Result MakeWriteNode(compute::ExecPlan* plan, diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 07b156778f6..6df5457a4f0 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -147,9 +147,26 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this> Inspect(const FileSource& source) const = 0; + /// \brief Scan the fragment returning the data as a stream of batches + /// + /// Typically it is expected that each generator task will involve some amount + /// of I/O work and some amount of CPU work. The format is responsible for transferring + /// the work back and forth between the I/O executor (obtained from the filesystem) and + /// CPU executor (passed in directly) as needed. + /// + /// If the underlying format does not support asynchronous operation then the correct + /// workaround is to schedule the entire task on the I/O executor. 
There is typically + /// no need to transfer back to the CPU executor as this will be done automatically by + /// the scan node. + /// + /// If the underlying format spends a lot of time blocked on I/O then it may be + /// necessary to make the I/O executor considerably larger than the number of cores on + /// the device to ensure maximum parallelism. This could potentially lead to excess + /// context switching, which is why an async approach is preferred. virtual Result<RecordBatchGenerator> ScanBatchesAsync( const std::shared_ptr<ScanOptions>& options, - const std::shared_ptr<FileFragment>& file) const = 0; + const std::shared_ptr<FileFragment>& file, + ::arrow::internal::Executor* cpu_executor) const = 0; virtual Future<util::optional<int64_t>> CountRows( const std::shared_ptr<FileFragment>& file, compute::Expression predicate, @@ -182,7 +199,8 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> Result<RecordBatchGenerator> ScanBatchesAsync( - const std::shared_ptr<ScanOptions>& options) override; + const std::shared_ptr<ScanOptions>& options, + ::arrow::internal::Executor* cpu_executor) override; Future<util::optional<int64_t>> CountRows( compute::Expression predicate, const std::shared_ptr<ScanOptions>& options) override; diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index 1cc7957083f..3c57b2a1e30 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -46,7 +46,6 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; using internal::Executor; -using internal::SerialExecutor; namespace dataset { @@ -245,11 +244,11 @@ Result<std::shared_ptr<Schema>> CsvFileFormat::Inspect(const FileSource& source) Result<RecordBatchGenerator> CsvFileFormat::ScanBatchesAsync( const std::shared_ptr<ScanOptions>& scan_options, - const std::shared_ptr<FileFragment>& file) const { + const std::shared_ptr<FileFragment>& file, + ::arrow::internal::Executor* cpu_executor) const { auto this_ = checked_pointer_cast<CsvFileFormat>(shared_from_this()); auto source = file->source(); - auto reader_fut = - OpenReaderAsync(source, *this, scan_options, ::arrow::internal::GetCpuThreadPool()); + auto reader_fut = OpenReaderAsync(source, *this, scan_options, cpu_executor); return GeneratorFromReader(std::move(reader_fut), scan_options->batch_size); } diff --git a/cpp/src/arrow/dataset/file_csv.h b/cpp/src/arrow/dataset/file_csv.h index 83dbb88b85f..7fdb82afbe4 100644 --- a/cpp/src/arrow/dataset/file_csv.h +++ b/cpp/src/arrow/dataset/file_csv.h @@ -55,7 +55,8 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat { Result<RecordBatchGenerator> ScanBatchesAsync( const std::shared_ptr<ScanOptions>& scan_options, - const std::shared_ptr<FileFragment>& file) const override; + const std::shared_ptr<FileFragment>& file, + ::arrow::internal::Executor* cpu_executor) const override; Future<util::optional<int64_t>> CountRows( const std::shared_ptr<FileFragment>& file, compute::Expression predicate, diff --git a/cpp/src/arrow/dataset/file_csv_test.cc b/cpp/src/arrow/dataset/file_csv_test.cc index e9088859005..3088d4517d6 100644 --- a/cpp/src/arrow/dataset/file_csv_test.cc +++ b/cpp/src/arrow/dataset/file_csv_test.cc @@ -99,7 +99,9 @@ class TestCsvFileFormat : public FileFormatFixtureMixin, } RecordBatchIterator Batches(Fragment* fragment) { - EXPECT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(opts_)); + EXPECT_OK_AND_ASSIGN( + auto batch_gen, + fragment->ScanBatchesAsync(opts_, ::arrow::internal::GetCpuThreadPool())); return MakeGeneratorIterator(batch_gen); } }; diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc index e386c7dea8d..46100003f93 100644 --- a/cpp/src/arrow/dataset/file_ipc.cc +++ b/cpp/src/arrow/dataset/file_ipc.cc @@ -120,7 +120,8 @@ Result<std::shared_ptr<Schema>> IpcFileFormat::Inspect(const FileSource& source) Result<RecordBatchGenerator> IpcFileFormat::ScanBatchesAsync( const
std::shared_ptr& options, - const std::shared_ptr& file) const { + const std::shared_ptr& file, + ::arrow::internal::Executor* cpu_executor) const { auto self = shared_from_this(); auto source = file->source(); auto open_reader = OpenReaderAsync(source); @@ -143,10 +144,10 @@ Result IpcFileFormat::ScanBatchesAsync( RecordBatchGenerator generator; if (ipc_scan_options->cache_options) { // Transferring helps performance when coalescing - ARROW_ASSIGN_OR_RAISE(generator, reader->GetRecordBatchGenerator( - /*coalesce=*/true, options->io_context, - *ipc_scan_options->cache_options, - ::arrow::internal::GetCpuThreadPool())); + ARROW_ASSIGN_OR_RAISE(generator, + reader->GetRecordBatchGenerator( + /*coalesce=*/true, options->io_context, + *ipc_scan_options->cache_options, cpu_executor)); } else { ARROW_ASSIGN_OR_RAISE(generator, reader->GetRecordBatchGenerator( /*coalesce=*/false, options->io_context)); diff --git a/cpp/src/arrow/dataset/file_ipc.h b/cpp/src/arrow/dataset/file_ipc.h index 29ce6be61d6..a0079424ec6 100644 --- a/cpp/src/arrow/dataset/file_ipc.h +++ b/cpp/src/arrow/dataset/file_ipc.h @@ -54,7 +54,8 @@ class ARROW_DS_EXPORT IpcFileFormat : public FileFormat { Result ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& file) const override; + const std::shared_ptr& file, + ::arrow::internal::Executor* cpu_executor) const override; Future> CountRows( const std::shared_ptr& file, compute::Expression predicate, diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index 7422ee9f527..9f01a7ddfc9 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -167,7 +167,8 @@ TEST_P(TestIpcFileFormatScan, FragmentScanOptions) { fragment_scan_options->options = std::make_shared(); fragment_scan_options->options->max_recursion_depth = 0; opts_->fragment_scan_options = fragment_scan_options; - ASSERT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(opts_)); + ASSERT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync( + opts_, ::arrow::internal::GetCpuThreadPool())); ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(batch_gen)); } INSTANTIATE_TEST_SUITE_P(TestScan, TestIpcFileFormatScan, diff --git a/cpp/src/arrow/dataset/file_orc.cc b/cpp/src/arrow/dataset/file_orc.cc index 49102f3deae..46c8dd21ee6 100644 --- a/cpp/src/arrow/dataset/file_orc.cc +++ b/cpp/src/arrow/dataset/file_orc.cc @@ -154,7 +154,8 @@ Result> OrcFileFormat::Inspect(const FileSource& source) Result OrcFileFormat::ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& file) const { + const std::shared_ptr& file, + ::arrow::internal::Executor* cpu_executor) const { // TODO investigate "true" async version // (https://issues.apache.org/jira/browse/ARROW-13795) ARROW_ASSIGN_OR_RAISE(auto task_iter, OrcScanTaskIterator::Make(options, file)); diff --git a/cpp/src/arrow/dataset/file_orc.h b/cpp/src/arrow/dataset/file_orc.h index 5bbe4df24ad..3805f81dddc 100644 --- a/cpp/src/arrow/dataset/file_orc.h +++ b/cpp/src/arrow/dataset/file_orc.h @@ -53,7 +53,8 @@ class ARROW_DS_EXPORT OrcFileFormat : public FileFormat { Result ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& file) const override; + const std::shared_ptr& file, + ::arrow::internal::Executor* cpu_executor) const override; Future> CountRows( const std::shared_ptr& file, compute::Expression predicate, diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 
cdee1f00684..7133a63140a 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -408,7 +408,8 @@ Future> ParquetFileFormat::GetReader Result ParquetFileFormat::ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& file) const { + const std::shared_ptr& file, + ::arrow::internal::Executor* cpu_executor) const { auto parquet_fragment = checked_pointer_cast(file); std::vector row_groups; bool pre_filtered = false; @@ -442,10 +443,9 @@ Result ParquetFileFormat::ScanBatchesAsync( // Assume 1 row group corresponds to 1 batch (this factor could be // improved by looking at metadata) int row_group_readahead = options->batch_readahead; - ARROW_ASSIGN_OR_RAISE( - auto generator, reader->GetRecordBatchGenerator( - reader, row_groups, column_projection, - ::arrow::internal::GetCpuThreadPool(), row_group_readahead)); + ARROW_ASSIGN_OR_RAISE(auto generator, reader->GetRecordBatchGenerator( + reader, row_groups, column_projection, + cpu_executor, row_group_readahead)); return generator; }; return MakeFromFuture(GetReaderAsync(parquet_fragment->source(), options) diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index 571ec816667..ab47822ac93 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -97,7 +97,8 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { Result ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& file) const override; + const std::shared_ptr& file, + ::arrow::internal::Executor* cpu_executor) const override; Future> CountRows( const std::shared_ptr& file, compute::Expression predicate, diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index d5c7a0b9850..ec3fa28507c 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -122,7 +122,9 @@ class ParquetFormatHelper { class TestParquetFileFormat : public FileFormatFixtureMixin { public: RecordBatchIterator Batches(Fragment* fragment) { - EXPECT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(opts_)); + EXPECT_OK_AND_ASSIGN( + auto batch_gen, + fragment->ScanBatchesAsync(opts_, ::arrow::internal::GetCpuThreadPool())); return MakeGeneratorIterator(batch_gen); } @@ -586,8 +588,10 @@ TEST_P(TestParquetFileFormatScan, ExplicitRowGroupSelection) { SetFilter(greater(field_ref("i64"), literal(3))); CountRowsAndBatchesInScan(row_groups_fragment({2, 3, 4, 5}), 4 + 5 + 6, 3); - ASSERT_OK_AND_ASSIGN(auto batch_gen, - row_groups_fragment({kNumRowGroups + 1})->ScanBatchesAsync(opts_)); + ASSERT_OK_AND_ASSIGN( + auto batch_gen, + row_groups_fragment({kNumRowGroups + 1}) + ->ScanBatchesAsync(opts_, ::arrow::internal::GetCpuThreadPool())); Status scan_status = CollectAsyncGenerator(batch_gen).status(); EXPECT_RAISES_WITH_MESSAGE_THAT( diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index cc89c163cb7..db3ba88eb42 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -91,7 +91,8 @@ class MockFileFormat : public FileFormat { public: Result ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& file) const override { + const std::shared_ptr& file, + ::arrow::internal::Executor* cpu_executor) const override { auto sch = schema({field("i32", int32())}); RecordBatchVector batches; for (int i = 0; i < kNumBatches; i++) { @@ -119,7 +120,8 @@ class MockFileFormat : public FileFormat { 
TEST(FileFormat, ScanAsync) { MockFileFormat format; auto scan_options = std::make_shared(); - ASSERT_OK_AND_ASSIGN(auto batch_gen, format.ScanBatchesAsync(scan_options, nullptr)); + ASSERT_OK_AND_ASSIGN(auto batch_gen, + format.ScanBatchesAsync(scan_options, nullptr, nullptr)); ASSERT_FINISHES_OK_AND_ASSIGN(auto batches, CollectAsyncGenerator(batch_gen)); ASSERT_EQ(kNumBatches, static_cast(batches.size())); for (int i = 0; i < kNumBatches; i++) { diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index b958f7b9e62..14418a9afda 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -186,6 +186,7 @@ class AsyncScanner : public Scanner, public std::enable_shared_from_this> TakeRows(const Array& indices) override; Result> Head(int64_t num_rows) override; Result> ToTable() override; + Future CountRowsAsync(Executor* cpu_executor); Result CountRows() override; Result> ToRecordBatchReader() override; const std::shared_ptr& dataset() const override; @@ -194,8 +195,7 @@ class AsyncScanner : public Scanner, public std::enable_shared_from_this ScanBatchesAsync(Executor* executor); Future<> VisitBatchesAsync(std::function visitor, Executor* executor); - Result ScanBatchesUnorderedAsync( - Executor* executor, bool sequence_fragments = false); + Result ScanBatchesUnorderedAsync(Executor* executor); Future> ToTableAsync(Executor* executor); Result GetFragments() const; @@ -205,8 +205,9 @@ class AsyncScanner : public Scanner, public std::enable_shared_from_this FragmentToBatches( const Enumerated>& fragment, - const std::shared_ptr& options) { - ARROW_ASSIGN_OR_RAISE(auto batch_gen, fragment.value->ScanBatchesAsync(options)); + const std::shared_ptr& options, Executor* cpu_executor) { + ARROW_ASSIGN_OR_RAISE(auto batch_gen, + fragment.value->ScanBatchesAsync(options, cpu_executor)); ArrayVector columns; for (const auto& field : options->dataset_schema->fields()) { // TODO(ARROW-7051): use helper to make empty batch @@ -228,11 +229,12 @@ Result FragmentToBatches( } Result> FragmentsToBatches( - FragmentGenerator fragment_gen, const std::shared_ptr& options) { + FragmentGenerator fragment_gen, const std::shared_ptr& options, + Executor* cpu_executor) { auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen)); return MakeMappedGenerator(std::move(enumerated_fragment_gen), [=](const Enumerated>& fragment) { - return FragmentToBatches(fragment, options); + return FragmentToBatches(fragment, options, cpu_executor); }); } @@ -248,13 +250,12 @@ class OneShotFragment : public Fragment { return Status::OK(); } Result ScanBatchesAsync( - const std::shared_ptr& options) override { + const std::shared_ptr& options, Executor* cpu_executor) override { RETURN_NOT_OK(CheckConsumed()); ARROW_ASSIGN_OR_RAISE( auto background_gen, MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor())); - return MakeTransferredGenerator(std::move(background_gen), - ::arrow::internal::GetCpuThreadPool()); + return MakeTransferredGenerator(std::move(background_gen), cpu_executor); } std::string type_name() const override { return "one-shot"; } @@ -276,20 +277,18 @@ Result AsyncScanner::GetFragments() const { } Result AsyncScanner::ScanBatches() { - ARROW_ASSIGN_OR_RAISE(auto batches_gen, - ScanBatchesAsync(::arrow::internal::GetCpuThreadPool())); - return MakeGeneratorIterator(std::move(batches_gen)); + return SerialExecutor::RunGeneratorInSerialExecutor( + [this](Executor* executor) { return ScanBatchesAsync(executor); }); } Result 
AsyncScanner::ScanBatchesUnordered() { - ARROW_ASSIGN_OR_RAISE(auto batches_gen, - ScanBatchesUnorderedAsync(::arrow::internal::GetCpuThreadPool())); - return MakeGeneratorIterator(std::move(batches_gen)); + return SerialExecutor::RunGeneratorInSerialExecutor( + [this](Executor* executor) { return ScanBatchesUnorderedAsync(executor); }); } Result> AsyncScanner::ToTable() { - auto table_fut = ToTableAsync(::arrow::internal::GetCpuThreadPool()); - return table_fut.result(); + return SerialExecutor::RunInSerialExecutor>( + [this](Executor* executor) { return ToTableAsync(executor); }); } Result AsyncScanner::ScanBatchesUnorderedAsync() { @@ -314,13 +313,8 @@ Result ToEnumeratedRecordBatch( } Result AsyncScanner::ScanBatchesUnorderedAsync( - Executor* cpu_executor, bool sequence_fragments) { - if (!scan_options_->use_threads) { - cpu_executor = nullptr; - } - + Executor* cpu_executor) { RETURN_NOT_OK(NormalizeScanOptions(scan_options_, dataset_->schema())); - auto exec_context = std::make_shared(scan_options_->pool, cpu_executor); @@ -337,8 +331,7 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( RETURN_NOT_OK( compute::Declaration::Sequence( { - {"scan", ScanNodeOptions{dataset_, scan_options_, backpressure.toggle, - sequence_fragments}}, + {"scan", ScanNodeOptions{dataset_, scan_options_, backpressure.toggle}}, {"filter", compute::FilterNodeOptions{scan_options_->filter}}, {"augmented_project", compute::ProjectNodeOptions{std::move(exprs), std::move(names)}}, @@ -484,8 +477,7 @@ Result AsyncScanner::ScanBatchesAsync() { Result AsyncScanner::ScanBatchesAsync( Executor* cpu_executor) { - ARROW_ASSIGN_OR_RAISE(auto unordered, ScanBatchesUnorderedAsync( - cpu_executor, /*sequence_fragments=*/true)); + ARROW_ASSIGN_OR_RAISE(auto unordered, ScanBatchesUnorderedAsync(cpu_executor)); // We need an initial value sentinel, so we use one with fragment.index < 0 auto is_before_any = [](const EnumeratedRecordBatch& batch) { return batch.fragment.index < 0; @@ -597,33 +589,40 @@ Future> AsyncScanner::ToTableAsync(Executor* cpu_executor }); } -Result AsyncScanner::CountRows() { +Future AsyncScanner::CountRowsAsync(Executor* cpu_executor) { ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments()); - auto cpu_executor = - scan_options_->use_threads ? 
::arrow::internal::GetCpuThreadPool() : nullptr; - compute::ExecContext exec_context(scan_options_->pool, cpu_executor); + std::shared_ptr exec_context = + std::make_shared(scan_options_->pool, cpu_executor); - ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get())); // Drop projection since we only need to count rows - const auto options = std::make_shared(*scan_options_); + auto options = std::make_shared(*scan_options_); ARROW_ASSIGN_OR_RAISE(auto empty_projection, ProjectionDescr::FromNames(std::vector(), *scan_options_->dataset_schema)); SetProjection(options.get(), empty_projection); - std::atomic total{0}; + struct CountRowsState { + explicit CountRowsState(std::shared_ptr options) + : options(std::move(options)) {} + + std::shared_ptr options; + std::atomic total{0}; + }; + std::shared_ptr state = std::make_shared(options); + CountRowsState* state_ptr = state.get(); fragment_gen = MakeMappedGenerator( - std::move(fragment_gen), [&](const std::shared_ptr& fragment) { - return fragment->CountRows(options->filter, options) + std::move(fragment_gen), [state_ptr](const std::shared_ptr& fragment) { + return fragment->CountRows(state_ptr->options->filter, state_ptr->options) .Then([&, fragment](util::optional fast_count) mutable -> std::shared_ptr { if (fast_count) { // fast path: got row count directly; skip scanning this fragment - total += *fast_count; - return std::make_shared(options->dataset_schema, - RecordBatchVector{}); + state_ptr->total += *fast_count; + return std::make_shared( + state_ptr->options->dataset_schema, RecordBatchVector{}); } // slow path: actually filter this fragment's batches @@ -636,10 +635,10 @@ Result AsyncScanner::CountRows() { RETURN_NOT_OK( compute::Declaration::Sequence( { - {"scan", ScanNodeOptions{std::make_shared( - scan_options_->dataset_schema, - std::move(fragment_gen)), - options}}, + {"scan", + ScanNodeOptions{std::make_shared(options->dataset_schema, + std::move(fragment_gen)), + options}}, {"project", compute::ProjectNodeOptions{{options->filter}, {"mask"}}}, {"aggregate", compute::AggregateNodeOptions{{compute::internal::Aggregate{ "sum", nullptr}}, @@ -650,13 +649,19 @@ Result AsyncScanner::CountRows() { .AddToPlan(plan.get())); RETURN_NOT_OK(plan->StartProducing()); - auto maybe_slow_count = sink_gen().result(); - plan->finished().Wait(); + return plan->finished().Then( + [sink_gen, state, plan, exec_context]() -> Result { + auto maybe_slow_count = sink_gen().result(); + ARROW_ASSIGN_OR_RAISE(auto slow_count, maybe_slow_count); + state->total += slow_count->values[0].scalar_as().value; - ARROW_ASSIGN_OR_RAISE(auto slow_count, maybe_slow_count); - total += slow_count->values[0].scalar_as().value; + return state->total.load(); + }); +} - return total.load(); +Result AsyncScanner::CountRows() { + return SerialExecutor::RunInSerialExecutor( + [this](Executor* executor) { return CountRowsAsync(executor); }); } Result> AsyncScanner::ToRecordBatchReader() { @@ -831,7 +836,6 @@ Result MakeScanNode(compute::ExecPlan* plan, auto scan_options = scan_node_options.scan_options; auto dataset = scan_node_options.dataset; const auto& backpressure_toggle = scan_node_options.backpressure_toggle; - bool require_sequenced_output = scan_node_options.require_sequenced_output; RETURN_NOT_OK(NormalizeScanOptions(scan_options, dataset->schema())); @@ -841,17 +845,12 @@ Result MakeScanNode(compute::ExecPlan* plan, auto fragment_gen = 
MakeVectorGenerator(std::move(fragments_vec)); ARROW_ASSIGN_OR_RAISE(auto batch_gen_gen, - FragmentsToBatches(std::move(fragment_gen), scan_options)); + FragmentsToBatches(std::move(fragment_gen), scan_options, + plan->exec_context()->executor())); AsyncGenerator merged_batch_gen; - if (require_sequenced_output) { - ARROW_ASSIGN_OR_RAISE(merged_batch_gen, - MakeSequencedMergedGenerator(std::move(batch_gen_gen), - scan_options->fragment_readahead)); - } else { - merged_batch_gen = - MakeMergedGenerator(std::move(batch_gen_gen), scan_options->fragment_readahead); - } + merged_batch_gen = + MakeMergedGenerator(std::move(batch_gen_gen), scan_options->fragment_readahead); auto batch_gen = MakeReadaheadGenerator(std::move(merged_batch_gen), scan_options->fragment_readahead); diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 013ba092b0c..3271b86a05a 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -404,17 +404,24 @@ class ARROW_DS_EXPORT ScanNodeOptions : public compute::ExecNodeOptions { public: explicit ScanNodeOptions( std::shared_ptr dataset, std::shared_ptr scan_options, - std::shared_ptr backpressure_toggle = NULLPTR, - bool require_sequenced_output = false) + std::shared_ptr backpressure_toggle = NULLPTR) : dataset(std::move(dataset)), scan_options(std::move(scan_options)), - backpressure_toggle(std::move(backpressure_toggle)), - require_sequenced_output(require_sequenced_output) {} + backpressure_toggle(std::move(backpressure_toggle)) {} + /// \brief The dataset to scan std::shared_ptr dataset; + /// \brief Instructions on how to scan the dataset. + /// + /// The use_threads option will be ignored. The scanner will always + /// use the exec plan's executor (which may be a serial executor) std::shared_ptr scan_options; + /// \brief An optional toggle to allow pausing the scan. + /// + /// The node will cease reading from the scanner while the toggle is closed. + /// The scanner will continue until its readahead queue has filled up at which + /// point scanning will pause. 
std::shared_ptr backpressure_toggle; - bool require_sequenced_output; }; /// @} diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index dc72bf638d1..11f67a3a303 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -48,6 +48,7 @@ using testing::UnorderedElementsAreArray; namespace arrow { +using internal::Executor; using internal::GetCpuThreadPool; using internal::Iota; @@ -481,8 +482,8 @@ class CountRowsOnlyFragment : public InMemoryFragment { } return Future>::MakeFinished(sum); } - Result ScanBatchesAsync( - const std::shared_ptr&) override { + Result ScanBatchesAsync(const std::shared_ptr&, + Executor* cpu_executor) override { return Status::Invalid("Don't scan me!"); } }; @@ -495,8 +496,8 @@ class ScanOnlyFragment : public InMemoryFragment { compute::Expression predicate, const std::shared_ptr&) override { return Future>::MakeFinished(util::nullopt); } - Result ScanBatchesAsync( - const std::shared_ptr&) override { + Result ScanBatchesAsync(const std::shared_ptr&, + Executor*) override { return MakeVectorGenerator(record_batches_); } }; @@ -589,7 +590,7 @@ class FailingFragment : public InMemoryFragment { public: using InMemoryFragment::InMemoryFragment; Result ScanBatchesAsync( - const std::shared_ptr& options) override { + const std::shared_ptr& options, Executor*) override { struct { Future> operator()() { if (index > 16) { @@ -613,7 +614,7 @@ class FailingScanFragment : public InMemoryFragment { // There are two places to fail - during iteration (covered by FailingFragment) or at // the initial scan (covered here) Result ScanBatchesAsync( - const std::shared_ptr& options) override { + const std::shared_ptr& options, Executor*) override { return Status::Invalid("Oh no, we failed!"); } }; @@ -768,7 +769,7 @@ class ControlledFragment : public Fragment { std::string type_name() const override { return "scanner_test.cc::ControlledFragment"; } Result ScanBatchesAsync( - const std::shared_ptr& options) override { + const std::shared_ptr& options, Executor*) override { return tracking_generator_; }; diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 3f826fa09c9..6524f6b5b8e 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -158,7 +158,9 @@ class DatasetFixtureMixin : public ::testing::Test { /// record batches yielded by the data fragment. void AssertFragmentEquals(RecordBatchReader* expected, Fragment* fragment, bool ensure_drained = true) { - ASSERT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(options_)); + ASSERT_OK_AND_ASSIGN( + auto batch_gen, + fragment->ScanBatchesAsync(options_, ::arrow::internal::GetCpuThreadPool())); AssertScanTaskEquals(expected, batch_gen); if (ensure_drained) { @@ -581,7 +583,9 @@ class FileFormatScanMixin : public FileFormatFixtureMixin, // Scan the fragment directly, without using the scanner. 
RecordBatchIterator PhysicalBatches(std::shared_ptr fragment) { opts_->use_threads = GetParam().use_threads; - EXPECT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(opts_)); + EXPECT_OK_AND_ASSIGN( + auto batch_gen, + fragment->ScanBatchesAsync(opts_, ::arrow::internal::GetCpuThreadPool())); auto batch_it = MakeGeneratorIterator(std::move(batch_gen)); return batch_it; } @@ -880,7 +884,8 @@ class DummyFileFormat : public FileFormat { /// \brief Open a file for scanning (always returns an empty generator) Result ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& fragment) const override { + const std::shared_ptr& fragment, + ::arrow::internal::Executor* cpu_executor) const override { return MakeEmptyGenerator>(); } @@ -920,7 +925,8 @@ class JSONRecordBatchFileFormat : public FileFormat { Result ScanBatchesAsync( const std::shared_ptr& options, - const std::shared_ptr& fragment) const override { + const std::shared_ptr& fragment, + ::arrow::internal::Executor* cpu_executor) const override { ARROW_ASSIGN_OR_RAISE(auto file, fragment->source().Open()); ARROW_ASSIGN_OR_RAISE(int64_t size, file->GetSize()); ARROW_ASSIGN_OR_RAISE(auto buffer, file->Read(size)); diff --git a/cpp/src/arrow/testing/async_test_util.h b/cpp/src/arrow/testing/async_test_util.h index b9f5487ed0d..acd19481369 100644 --- a/cpp/src/arrow/testing/async_test_util.h +++ b/cpp/src/arrow/testing/async_test_util.h @@ -20,12 +20,37 @@ #include #include +#include "arrow/testing/gtest_util.h" #include "arrow/util/async_generator.h" #include "arrow/util/future.h" namespace arrow { namespace util { +template +AsyncGenerator AsyncVectorIt(std::vector v) { + return MakeVectorGenerator(std::move(v)); +} + +template +AsyncGenerator FailsAt(AsyncGenerator src, int failing_index) { + auto index = std::make_shared>(0); + return [src, index, failing_index]() { + auto idx = index->fetch_add(1); + if (idx >= failing_index) { + return Future::MakeFinished(Status::Invalid("XYZ")); + } + return src(); + }; +} + +template +AsyncGenerator SlowdownABit(AsyncGenerator source) { + return MakeMappedGenerator(std::move(source), [](const T& res) { + return SleepABitAsync().Then([res]() { return res; }); + }); +} + template class TrackingGenerator { public: diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index 1983819f445..8662c319c99 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -79,6 +79,7 @@ add_arrow_test(threading-utility-test counting_semaphore_test.cc future_test.cc task_group_test.cc + test_common.cc thread_pool_test.cc) add_arrow_benchmark(bit_block_counter_benchmark) diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index 0948e5537fe..d850014d0b9 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -60,6 +60,9 @@ namespace arrow { // Readahead operators, and some other operators, may introduce queueing. Any operators // that introduce buffering should detail the amount of buffering they introduce in their // MakeXYZ function comments. +// +// A generator should always be fully consumed before it is destroyed. +// A generator should not emit a terminal item until it has finished all ongoing futures. 
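The async_generator.h change below enforces that second rule in ReadaheadGenerator by counting in-flight futures and only surfacing a terminal error once the count drains to zero. Distilled into a standalone sketch (the helper and its names are hypothetical, not part of the patch):

```cpp
#include <atomic>

#include "arrow/util/future.h"

// Hypothetical helper showing the counting pattern ReadaheadGenerator adopts
// below: a terminal item (including an error) is only surfaced once every
// in-flight future has settled.
struct InFlightTracker {
  std::atomic<int> num_running{0};
  arrow::Future<> all_done = arrow::Future<>::Make();

  void Started() { num_running.fetch_add(1); }

  // The task that decrements the count to zero marks `all_done`; an error
  // result is chained onto `all_done` instead of being emitted immediately.
  void Finished() {
    if (num_running.fetch_sub(1) == 1) {
      all_done.MarkFinished();
    }
  }
};
```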
template using AsyncGenerator = std::function()>; @@ -750,19 +753,32 @@ class ReadaheadGenerator { Future AddMarkFinishedContinuation(Future fut) { auto state = state_; return fut.Then( - [state](const T& result) -> Result { + [state](const T& result) -> Future { state->MarkFinishedIfDone(result); + if (state->finished.load()) { + if (state->num_running.fetch_sub(1) == 1) { + state->final_future.MarkFinished(); + } + } else { + state->num_running.fetch_sub(1); + } return result; }, - [state](const Status& err) -> Result { + [state](const Status& err) -> Future { + // If there is an error we need to make sure all running + // tasks finish before we return the error. state->finished.store(true); - return err; + if (state->num_running.fetch_sub(1) == 1) { + state->final_future.MarkFinished(); + } + return state->final_future.Then([err]() -> Result { return err; }); }); } Future operator()() { if (state_->readahead_queue.empty()) { // This is the first request, let's pump the underlying queue + state_->num_running.store(state_->max_readahead); for (int i = 0; i < state_->max_readahead; i++) { auto next = state_->source_generator(); auto next_after_check = AddMarkFinishedContinuation(std::move(next)); @@ -775,6 +791,7 @@ class ReadaheadGenerator { if (state_->finished.load()) { state_->readahead_queue.push(AsyncGeneratorEnd()); } else { + state_->num_running.fetch_add(1); auto back_of_queue = state_->source_generator(); auto back_of_queue_after_check = AddMarkFinishedContinuation(std::move(back_of_queue)); @@ -786,9 +803,7 @@ class ReadaheadGenerator { private: struct State { State(AsyncGenerator source_generator, int max_readahead) - : source_generator(std::move(source_generator)), max_readahead(max_readahead) { - finished.store(false); - } + : source_generator(std::move(source_generator)), max_readahead(max_readahead) {} void MarkFinishedIfDone(const T& next_result) { if (IsIterationEnd(next_result)) { @@ -798,7 +813,9 @@ class ReadaheadGenerator { AsyncGenerator source_generator; int max_readahead; - std::atomic finished; + Future<> final_future = Future<>::Make(); + std::atomic num_running{0}; + std::atomic finished{false}; std::queue> readahead_queue; }; diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index 7e5fccd9ef1..f1e3edf37cf 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -37,30 +37,6 @@ namespace arrow { -template -AsyncGenerator AsyncVectorIt(std::vector v) { - return MakeVectorGenerator(std::move(v)); -} - -template -AsyncGenerator FailsAt(AsyncGenerator src, int failing_index) { - auto index = std::make_shared>(0); - return [src, index, failing_index]() { - auto idx = index->fetch_add(1); - if (idx >= failing_index) { - return Future::MakeFinished(Status::Invalid("XYZ")); - } - return src(); - }; -} - -template -AsyncGenerator SlowdownABit(AsyncGenerator source) { - return MakeMappedGenerator(std::move(source), [](const T& res) { - return SleepABitAsync().Then([res]() { return res; }); - }); -} - template AsyncGenerator MakeJittery(AsyncGenerator source) { auto latency_generator = arrow::io::LatencyGenerator::Make(0.01); @@ -217,9 +193,9 @@ class GeneratorTestFixture : public ::testing::TestWithParam { protected: AsyncGenerator MakeSource(const std::vector& items) { std::vector wrapped(items.begin(), items.end()); - auto gen = AsyncVectorIt(std::move(wrapped)); + auto gen = util::AsyncVectorIt(std::move(wrapped)); if (IsSlow()) { - return SlowdownABit(std::move(gen)); 
+ return util::SlowdownABit(std::move(gen)); } return gen; } @@ -231,7 +207,7 @@ class GeneratorTestFixture : public ::testing::TestWithParam { return Future::MakeFinished(Status::Invalid("XYZ")); }; if (IsSlow()) { - return SlowdownABit(std::move(gen)); + return util::SlowdownABit(std::move(gen)); } return gen; } @@ -324,7 +300,7 @@ class ManualGenerator { }; TEST(TestAsyncUtil, Visit) { - auto generator = AsyncVectorIt({1, 2, 3}); + auto generator = util::AsyncVectorIt({1, 2, 3}); unsigned int sum = 0; auto sum_future = VisitAsyncGenerator(generator, [&sum](TestInt item) { sum += item.value; @@ -336,7 +312,7 @@ TEST(TestAsyncUtil, Visit) { TEST(TestAsyncUtil, Collect) { std::vector expected = {1, 2, 3}; - auto generator = AsyncVectorIt(expected); + auto generator = util::AsyncVectorIt(expected); auto collected = CollectAsyncGenerator(generator); ASSERT_FINISHES_OK_AND_ASSIGN(auto collected_val, collected); ASSERT_EQ(expected, collected_val); @@ -344,7 +320,7 @@ TEST(TestAsyncUtil, Collect) { TEST(TestAsyncUtil, Map) { std::vector input = {1, 2, 3}; - auto generator = AsyncVectorIt(input); + auto generator = util::AsyncVectorIt(input); std::function mapper = [](const TestInt& in) { return std::to_string(in.value); }; @@ -355,7 +331,7 @@ TEST(TestAsyncUtil, Map) { TEST(TestAsyncUtil, MapAsync) { std::vector input = {1, 2, 3}; - auto generator = AsyncVectorIt(input); + auto generator = util::AsyncVectorIt(input); std::function(const TestInt&)> mapper = [](const TestInt& in) { return SleepAsync(1e-3).Then([in]() { return TestStr(std::to_string(in.value)); }); }; @@ -366,7 +342,7 @@ TEST(TestAsyncUtil, MapAsync) { TEST(TestAsyncUtil, MapReentrant) { std::vector input = {1, 2}; - auto source = AsyncVectorIt(input); + auto source = util::AsyncVectorIt(input); util::TrackingGenerator tracker(std::move(source)); source = MakeTransferredGenerator(AsyncGenerator(tracker), internal::GetCpuThreadPool()); @@ -408,7 +384,7 @@ TEST(TestAsyncUtil, MapParallelStress) { constexpr int NITEMS = 10; for (int i = 0; i < NTASKS; i++) { auto gen = MakeVectorGenerator(RangeVector(NITEMS)); - gen = SlowdownABit(std::move(gen)); + gen = util::SlowdownABit(std::move(gen)); auto guard = ExpectNotAccessedReentrantly(&gen); std::function mapper = [](const TestInt& in) { SleepABit(); @@ -427,9 +403,9 @@ TEST(TestAsyncUtil, MapQueuingFailStress) { for (bool slow : {true, false}) { for (int i = 0; i < NTASKS; i++) { std::shared_ptr> done = std::make_shared>(); - auto inner = AsyncVectorIt(RangeVector(NITEMS)); + auto inner = util::AsyncVectorIt(RangeVector(NITEMS)); if (slow) inner = MakeJittery(inner); - auto gen = FailsAt(inner, NITEMS / 2); + auto gen = util::FailsAt(inner, NITEMS / 2); std::function mapper = [done](const TestInt& in) { if (done->load()) { ADD_FAILURE() << "Callback called after generator sent end signal"; @@ -446,7 +422,7 @@ TEST(TestAsyncUtil, MapQueuingFailStress) { TEST(TestAsyncUtil, MapTaskFail) { std::vector input = {1, 2, 3}; - auto generator = AsyncVectorIt(input); + auto generator = util::AsyncVectorIt(input); std::function(const TestInt&)> mapper = [](const TestInt& in) -> Result { if (in.value == 2) { @@ -492,7 +468,7 @@ TEST(TestAsyncUtil, MapTaskDelayedFail) { TEST(TestAsyncUtil, MapSourceFail) { std::vector input = {1, 2, 3}; - auto generator = FailsAt(AsyncVectorIt(input), 1); + auto generator = util::FailsAt(util::AsyncVectorIt(input), 1); std::function(const TestInt&)> mapper = [](const TestInt& in) -> Result { return TestStr(std::to_string(in.value)); @@ -505,8 +481,8 @@ 
TEST(TestAsyncUtil, Concatenated) { std::vector inputOne{1, 2, 3}; std::vector inputTwo{4, 5, 6}; std::vector expected{1, 2, 3, 4, 5, 6}; - auto gen = AsyncVectorIt>( - {AsyncVectorIt(inputOne), AsyncVectorIt(inputTwo)}); + auto gen = util::AsyncVectorIt>( + {util::AsyncVectorIt(inputOne), util::AsyncVectorIt(inputTwo)}); auto concat = MakeConcatenatedGenerator(gen); AssertAsyncGeneratorMatch(expected, concat); } @@ -523,7 +499,7 @@ TEST_P(FromFutureFixture, Basic) { auto to_gen = source.Then([slow](const std::vector& vec) { auto vec_gen = MakeVectorGenerator(vec); if (slow) { - return SlowdownABit(std::move(vec_gen)); + return util::SlowdownABit(std::move(vec_gen)); } return vec_gen; }); @@ -538,7 +514,7 @@ INSTANTIATE_TEST_SUITE_P(FromFutureTests, FromFutureFixture, class MergedGeneratorTestFixture : public GeneratorTestFixture {}; TEST_P(MergedGeneratorTestFixture, Merged) { - auto gen = AsyncVectorIt>( + auto gen = util::AsyncVectorIt>( {MakeSource({1, 2, 3}), MakeSource({4, 5, 6})}); auto concat_gen = MakeMergedGenerator(gen, 10); @@ -552,23 +528,23 @@ TEST_P(MergedGeneratorTestFixture, Merged) { } TEST_P(MergedGeneratorTestFixture, MergedInnerFail) { - auto gen = AsyncVectorIt>( + auto gen = util::AsyncVectorIt>( {MakeSource({1, 2, 3}), MakeFailingSource()}); auto merged_gen = MakeMergedGenerator(gen, 10); ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen)); } TEST_P(MergedGeneratorTestFixture, MergedOuterFail) { - auto gen = - FailsAt(AsyncVectorIt>( - {MakeSource({1, 2, 3}), MakeSource({1, 2, 3}), MakeSource({1, 2, 3})}), - 1); + auto gen = util::FailsAt( + util::AsyncVectorIt>( + {MakeSource({1, 2, 3}), MakeSource({1, 2, 3}), MakeSource({1, 2, 3})}), + 1); auto merged_gen = MakeMergedGenerator(gen, 10); ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen)); } TEST_P(MergedGeneratorTestFixture, MergedLimitedSubscriptions) { - auto gen = AsyncVectorIt>( + auto gen = util::AsyncVectorIt>( {MakeSource({1, 2}), MakeSource({3, 4}), MakeSource({5, 6, 7, 8}), MakeSource({9, 10, 11, 12})}); util::TrackingGenerator> tracker(std::move(gen)); @@ -619,7 +595,7 @@ TEST_P(MergedGeneratorTestFixture, MergedStress) { guards.push_back(ExpectNotAccessedReentrantly(&source)); sources.push_back(source); } - AsyncGenerator> source_gen = AsyncVectorIt(sources); + AsyncGenerator> source_gen = util::AsyncVectorIt(sources); auto outer_gaurd = ExpectNotAccessedReentrantly(&source_gen); auto merged = MakeMergedGenerator(source_gen, 4); @@ -636,7 +612,7 @@ TEST_P(MergedGeneratorTestFixture, MergedParallelStress) { for (int j = 0; j < NGENERATORS; j++) { sources.push_back(MakeSource(RangeVector(NITEMS))); } - auto merged = MakeMergedGenerator(AsyncVectorIt(sources), 4); + auto merged = MakeMergedGenerator(util::AsyncVectorIt(sources), 4); merged = MakeReadaheadGenerator(merged, 4); ASSERT_FINISHES_OK_AND_ASSIGN(auto items, CollectAsyncGenerator(merged)); ASSERT_EQ(NITEMS * NGENERATORS, items.size()); @@ -1311,7 +1287,7 @@ TEST(TestAsyncUtil, Readahead) { } TEST(TestAsyncUtil, ReadaheadCopy) { - auto source = AsyncVectorIt(RangeVector(6)); + auto source = util::AsyncVectorIt(RangeVector(6)); auto gen = MakeReadaheadGenerator(std::move(source), 2); for (int i = 0; i < 2; i++) { @@ -1329,7 +1305,7 @@ TEST(TestAsyncUtil, ReadaheadCopy) { } TEST(TestAsyncUtil, ReadaheadMove) { - auto source = AsyncVectorIt(RangeVector(6)); + auto source = util::AsyncVectorIt(RangeVector(6)); auto gen = MakeReadaheadGenerator(std::move(source), 2); for (int i = 0; i < 2; i++) { @@ -1376,6 +1352,59 
@@ TEST(TestAsyncUtil, ReadaheadFailed) {
   }
 }
 
+TEST(TestAsyncUtil, ReadaheadFailedWaitForInFlight) {
+  ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(20));
+  // If a failure causes an early end then we should not emit that failure
+  // until all in-flight futures have completed.  This is to prevent tasks from
+  // outliving the generator
+  std::atomic<int> counter(0);
+  auto failure_gating_task = GatingTask::Make();
+  auto in_flight_gating_task = GatingTask::Make();
+  auto source = [&]() -> Future<TestInt> {
+    auto count = counter++;
+    return DeferNotOk(thread_pool->Submit([&, count]() -> Result<TestInt> {
+      if (count == 0) {
+        failure_gating_task->Task()();
+        return Status::Invalid("X");
+      }
+      in_flight_gating_task->Task()();
+      // These are our in-flight tasks
+      return TestInt(0);
+    }));
+  };
+  auto readahead = MakeReadaheadGenerator<TestInt>(source, 10);
+  auto should_be_invalid = readahead();
+  ASSERT_OK(in_flight_gating_task->WaitForRunning(10));
+  ASSERT_OK(failure_gating_task->Unlock());
+  SleepABit();
+  // Can't be finished because in-flight tasks are still running
+  AssertNotFinished(should_be_invalid);
+  ASSERT_OK(in_flight_gating_task->Unlock());
+  ASSERT_FINISHES_AND_RAISES(Invalid, should_be_invalid);
+}
+
+TEST(TestAsyncUtil, ReadaheadFailedStress) {
+  constexpr int NTASKS = 10;
+  ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(20));
+  for (int i = 0; i < NTASKS; i++) {
+    std::atomic<int> counter(0);
+    std::atomic<bool> finished(false);
+    AsyncGenerator<TestInt> source = [&]() -> Future<TestInt> {
+      auto count = counter++;
+      return DeferNotOk(thread_pool->Submit([&, count]() -> Result<TestInt> {
+        SleepABit();
+        if (count == 5) {
+          return Status::Invalid("X");
+        }
+        // Generator should not have been finished at this point
+        EXPECT_FALSE(finished);
+        return TestInt(0);
+      }));
+    };
+    ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(source));
+    finished.store(true);
+  }
+}
+
 class EnumeratorTestFixture : public GeneratorTestFixture {
  protected:
   void AssertEnumeratedCorrectly(AsyncGenerator<Enumerated<TestInt>>& gen,
@@ -1409,7 +1438,7 @@ TEST_P(EnumeratorTestFixture, Empty) {
 }
 
 TEST_P(EnumeratorTestFixture, Error) {
-  auto source = FailsAt(MakeSource({1, 2, 3}), 1);
+  auto source = util::FailsAt(MakeSource({1, 2, 3}), 1);
   auto enumerated = MakeEnumeratedGenerator(std::move(source));
 
   // Even though the first item finishes ok the enumerator buffers it.  The error then
@@ -1475,7 +1504,7 @@ class PauseableTestFixture : public GeneratorTestFixture {
   AsyncGenerator<TestInt> GetSource() {
     const auto& source = static_cast<AsyncGenerator<TestInt>>(generator_);
     if (IsSlow()) {
-      return SlowdownABit(source);
+      return util::SlowdownABit(source);
     } else {
       return source;
     }
@@ -1550,7 +1579,7 @@ TEST_P(SequencerTestFixture, SequenceLambda) {
 TEST_P(SequencerTestFixture, SequenceError) {
   {
     auto original = MakeSource({6, 4, 2});
-    original = FailsAt(original, 1);
+    original = util::FailsAt(original, 1);
     auto sequenced = MakeSequencingGenerator(original, cmp_, is_next_, TestInt(0));
     auto collected = CollectAsyncGenerator(sequenced);
     ASSERT_FINISHES_AND_RAISES(Invalid, collected);
@@ -1633,7 +1662,7 @@ INSTANTIATE_TEST_SUITE_P(SequencerTests, SequencerTestFixture,
                          ::testing::Values(false, true));
 
 TEST(TestAsyncIteratorTransform, SkipSome) {
-  auto original = AsyncVectorIt<TestInt>({1, 2, 3});
+  auto original = util::AsyncVectorIt<TestInt>({1, 2, 3});
   auto filter = MakeFilter([](TestInt& t) { return t.value != 2; });
   auto filtered = MakeTransformedGenerator(std::move(original), filter);
   AssertAsyncGeneratorMatch({"1", "3"}, std::move(filtered));
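The wait-for-in-flight behavior covered by these tests relies on the counting idiom added to ReadaheadGenerator above (num_running plus final_future); the same idiom in isolation, as a sketch (hypothetical TaskTracker, not part of this diff):

    struct TaskTracker {
      std::atomic<int> num_running{0};
      arrow::Future<> all_done = arrow::Future<>::Make();
      // Whichever task brings the count to zero marks the future, so an error
      // path can wait on all_done before it is allowed to surface.
      void TaskFinished() {
        if (num_running.fetch_sub(1) == 1) all_done.MarkFinished();
      }
    };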
diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc
index a1387947e3a..0538da98e3e 100644
--- a/cpp/src/arrow/util/thread_pool.cc
+++ b/cpp/src/arrow/util/thread_pool.cc
@@ -33,8 +33,6 @@
 namespace arrow {
 namespace internal {
 
-Executor::~Executor() = default;
-
 namespace {
 
 struct Task {
@@ -49,12 +47,28 @@ struct SerialExecutor::State {
   std::deque<Task> task_queue;
   std::mutex mutex;
   std::condition_variable wait_for_tasks;
+  bool paused{false};
   bool finished{false};
 };
 
 SerialExecutor::SerialExecutor() : state_(std::make_shared<State>()) {}
 
-SerialExecutor::~SerialExecutor() = default;
+SerialExecutor::~SerialExecutor() {
+  auto state = state_;
+  std::unique_lock<std::mutex> lk(state->mutex);
+  if (!state->task_queue.empty()) {
+    // We may have remaining tasks if the executor is being abandoned.  In general
+    // we could have resource leakage in this case.  However, we can make a best
+    // effort and clean up running work.
+    state->paused = false;
+    lk.unlock();
+    RunLoop();
+    lk.lock();
+  }
+  // Just in case some I/O thread comes in and tries to schedule something as we
+  // are tearing down.
+  state_->finished = true;
+}
 
 Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce<void()> task, StopToken stop_token,
                                  StopCallback&& stop_callback) {
@@ -68,6 +82,11 @@ Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce<void()> task,
   auto state = state_;
   {
     std::lock_guard<std::mutex> lk(state->mutex);
+    if (state_->finished) {
+      return Status::Invalid(
+          "Attempt to schedule a task on a serial executor that has already finished "
+          "or been abandoned");
+    }
     state->task_queue.push_back(
         Task{std::move(task), std::move(stop_token), std::move(stop_callback)});
   }
@@ -75,8 +94,17 @@ Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce<void()> task,
   return Status::OK();
 }
 
-void SerialExecutor::MarkFinished() {
+void SerialExecutor::Pause() {
   // Same comment as SpawnReal above
+  auto state = state_;
+  {
+    std::lock_guard<std::mutex> lk(state->mutex);
+    state->paused = true;
+  }
+  state->wait_for_tasks.notify_one();
+}
+
+void SerialExecutor::Finish() {
   auto state = state_;
   {
     std::lock_guard<std::mutex> lk(state->mutex);
@@ -85,13 +113,30 @@
   state->wait_for_tasks.notify_one();
 }
 
+bool SerialExecutor::IsFinished() {
+  std::lock_guard<std::mutex> lk(state_->mutex);
+  return state_->finished;
+}
+
+void SerialExecutor::Unpause() {
+  auto state = state_;
+  {
+    std::lock_guard<std::mutex> lk(state->mutex);
+    state->paused = false;
+  }
+}
+
 void SerialExecutor::RunLoop() {
   // This is called from the SerialExecutor's main thread, so the
   // state is guaranteed to be kept alive.
   std::unique_lock<std::mutex> lk(state_->mutex);
 
-  while (!state_->finished) {
-    while (!state_->task_queue.empty()) {
+  while (!state_->paused && !(state_->finished && state_->task_queue.empty())) {
+    // The inner loop runs tasks; we drop out of it when we may need to sleep (e.g.
+    // while waiting on some async task to finish from another thread pool).  We
+    // still need to check paused because sometimes we will pause even with work
+    // left over when processing an async generator
+    while (!state_->paused && !state_->task_queue.empty()) {
       Task task = std::move(state_->task_queue.front());
       state_->task_queue.pop_front();
       lk.unlock();
@@ -108,8 +153,9 @@ void SerialExecutor::RunLoop() {
     }
     // In this case we must be waiting on work from external (e.g. I/O) executors.  Wait
     // for tasks to arrive (typically via transferred futures).
-    state_->wait_for_tasks.wait(
-        lk, [&] { return state_->finished || !state_->task_queue.empty(); });
+    state_->wait_for_tasks.wait(lk, [&] {
+      return state_->paused || state_->finished || !state_->task_queue.empty();
+    });
   }
 }
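The Pause/Unpause/Finish hooks added above exist so a synchronous caller can lend its own thread to the executor. Condensed, the driving pattern (this is the pattern used by the SerialIterator in the header changes below, not a separate API) looks like:

    executor->Unpause();
    Future<TestInt> next = generator();  // may enqueue follow-up tasks
    next.AddCallback(
        [&](const Result<TestInt>&) { executor->Pause(); });  // wake and stop RunLoop
    executor->RunLoop();  // run queued tasks on this thread until paused
    // `next` is finished here; leftover tasks run on the next round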
diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h
index a104e0e3590..689d07d0006 100644
--- a/cpp/src/arrow/util/thread_pool.h
+++ b/cpp/src/arrow/util/thread_pool.h
@@ -36,6 +36,7 @@
 #include "arrow/util/cancel.h"
 #include "arrow/util/functional.h"
 #include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/visibility.h"
 
@@ -82,7 +83,7 @@ class ARROW_EXPORT Executor {
  public:
  using StopCallback = internal::FnOnce<void(const Status&)>;
 
-  virtual ~Executor();
+  virtual ~Executor() = default;
 
  // Spawn a fire-and-forget task.
  template <typename Function>
@@ -276,6 +277,73 @@ class ARROW_EXPORT SerialExecutor : public Executor {
     return FutureToSync(fut);
   }
 
+  template <typename T>
+  static Iterator<T> RunGeneratorInSerialExecutor(
+      internal::FnOnce<Result<std::function<Future<T>()>>(Executor*)> initial_task) {
+    auto serial_executor = std::unique_ptr<SerialExecutor>(new SerialExecutor());
+    auto maybe_generator = std::move(initial_task)(serial_executor.get());
+    if (!maybe_generator.ok()) {
+      return MakeErrorIterator<T>(maybe_generator.status());
+    }
+    auto generator = maybe_generator.MoveValueUnsafe();
+    struct SerialIterator {
+      SerialIterator(std::unique_ptr<SerialExecutor> executor,
+                     std::function<Future<T>()> generator)
+          : executor(std::move(executor)), generator(std::move(generator)) {}
+      ARROW_DISALLOW_COPY_AND_ASSIGN(SerialIterator);
+      ARROW_DEFAULT_MOVE_AND_ASSIGN(SerialIterator);
+      ~SerialIterator() {
+        // A serial iterator must be fully consumed before it is destroyed; allowing
+        // anything else would lead to resource leakage.  There will likely be
+        // deadlocks at this spot in the future, but those will be the result of
+        // other bugs and not of the fact that we force consumption here.
+
+        // If a streaming API needs to support early abandonment then it should do
+        // so with a cancellation token and not by simply discarding the iterator
+        // and expecting the underlying work to clean up correctly.
+        if (executor && !executor->IsFinished()) {
+          while (true) {
+            Result<T> maybe_next = Next();
+            if (!maybe_next.ok() || IsIterationEnd(*maybe_next)) {
+              break;
+            }
+          }
+        }
+      }
+
+      Result<T> Next() {
+        executor->Unpause();
+        // This call will probably lead to a bunch of tasks being
+        // scheduled in the serial executor
+        Future<T> next_fut = generator();
+        next_fut.AddCallback([this](const Result<T>& res) {
+          // If we're done iterating we should drain the rest of the tasks in the
+          // executor
+          if (!res.ok() || IsIterationEnd(*res)) {
+            executor->Finish();
+            return;
+          }
+          // Otherwise we will break out immediately, leaving the remaining tasks
+          // for the next call.
+          executor->Pause();
+        });
+        // Borrow this thread and run tasks until the future is finished
+        executor->RunLoop();
+        if (!next_fut.is_finished()) {
+          // It is not clear whether this can actually happen
+          return Status::Invalid(
+              "Serial executor terminated before next result computed");
+        }
+        // At this point we may still have tasks in the executor; that is ok.
+        // We will run those tasks the next time through.
+        return next_fut.result();
+      }
+
+      std::unique_ptr<SerialExecutor> executor;
+      std::function<Future<T>()> generator;
+    };
+    return Iterator<T>(SerialIterator{std::move(serial_executor), std::move(generator)});
+  }
+
  private:
  SerialExecutor();
 
@@ -283,18 +351,19 @@ class ARROW_EXPORT SerialExecutor : public Executor {
   struct State;
   std::shared_ptr<State> state_;
 
+  void RunLoop();
+  void Finish();
+  bool IsFinished();
+  void Pause();
+  void Unpause();
+
   template <typename T, typename FTSync = typename Future<T>::SyncType>
   Future<T> Run(TopLevelTask<T> initial_task) {
     auto final_fut = std::move(initial_task)(this);
-    if (final_fut.is_finished()) {
-      return final_fut;
-    }
-    final_fut.AddCallback([this](const FTSync&) { MarkFinished(); });
+    final_fut.AddCallback([this](const FTSync&) { Finish(); });
     RunLoop();
     return final_fut;
   }
-  void RunLoop();
-  void MarkFinished();
 };
 
 /// An Executor implementation spawning tasks in FIFO manner on a fixed-size
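Usage-wise, the initial task receives the serial executor and returns the generator whose futures will be pumped on it; a minimal sketch (mirroring the tests that follow, with MakeTransferredGenerator ensuring continuations land on the serial executor):

    Iterator<TestInt> iter = SerialExecutor::RunGeneratorInSerialExecutor<TestInt>(
        [](Executor* executor) -> Result<std::function<Future<TestInt>()>> {
          AsyncGenerator<TestInt> gen = MakeVectorGenerator<TestInt>({1, 2, 3});
          return MakeTransferredGenerator(std::move(gen), executor);
        });
    // Each call to iter.Next() borrows the calling thread to drive the executor.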
diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc
index 56efedca74a..832effadd87 100644
--- a/cpp/src/arrow/util/thread_pool_test.cc
+++ b/cpp/src/arrow/util/thread_pool_test.cc
@@ -32,6 +32,7 @@
 #include
 
 #include "arrow/status.h"
+#include "arrow/testing/async_test_util.h"
 #include "arrow/testing/executor_util.h"
 #include "arrow/testing/future_util.h"
 #include "arrow/testing/gtest_util.h"
@@ -260,6 +261,139 @@ TEST_P(TestRunSynchronously, PropagatedError) {
 
 INSTANTIATE_TEST_SUITE_P(TestRunSynchronously, TestRunSynchronously,
                          ::testing::Values(false, true));
 
+TEST(SerialExecutor, AsyncGenerator) {
+  std::vector<TestInt> values{1, 2, 3, 4, 5};
+  auto source = util::SlowdownABit(util::AsyncVectorIt(values));
+  Iterator<TestInt> iter = SerialExecutor::RunGeneratorInSerialExecutor<TestInt>(
+      [&source](Executor* executor) {
+        return MakeMappedGenerator(source, [executor](const TestInt& ti) {
+          return DeferNotOk(executor->Submit([ti] { return ti; }));
+        });
+      });
+  ASSERT_OK_AND_ASSIGN(auto vec, iter.ToVector());
+  ASSERT_EQ(vec, values);
+}
+
+TEST(SerialExecutor, AsyncGeneratorWithFollowUp) {
+  // Sometimes a task will generate follow-up tasks.  These must be run
+  // before the next task is started
+  bool follow_up_ran = false;
+  bool first = true;
+  Iterator<TestInt> iter =
+      SerialExecutor::RunGeneratorInSerialExecutor<TestInt>([&](Executor* executor) {
+        return [=, &first, &follow_up_ran]() -> Future<TestInt> {
+          if (first) {
+            first = false;
+            Future<TestInt> end =
+                DeferNotOk(executor->Submit([] { return TestInt(0); }));
+            RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; }));
+            return end;
+          }
+          return DeferNotOk(executor->Submit([] { return IterationEnd<TestInt>(); }));
+        };
+      });
+  ASSERT_FALSE(follow_up_ran);
+  ASSERT_OK_AND_EQ(TestInt(0), iter.Next());
+  ASSERT_FALSE(follow_up_ran);
+  ASSERT_OK_AND_EQ(IterationEnd<TestInt>(), iter.Next());
+  ASSERT_TRUE(follow_up_ran);
+}
+
+TEST(SerialExecutor, AsyncGeneratorWithAsyncFollowUp) {
+  // Simulates a situation where a user calls into the async generator, tasks (e.g.
I/O + // readahead tasks) are spawned onto the I/O threadpool, the user gets a result, and + // then the I/O readahead tasks are completed while there is no calling thread in the + // async generator to hand the task off to (it should be queued up) + bool follow_up_ran = false; + bool first = true; + Executor* captured_executor; + Iterator iter = + SerialExecutor::RunGeneratorInSerialExecutor([&](Executor* executor) { + return [=, &first, &captured_executor]() -> Future { + if (first) { + captured_executor = executor; + first = false; + return DeferNotOk(executor->Submit([] { + // I/O tasks would be scheduled at this point + return TestInt(0); + })); + } + return DeferNotOk(executor->Submit([] { return IterationEnd(); })); + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(TestInt(0), iter.Next()); + // I/O task completes and has reference to executor to submit continuation + ASSERT_OK(captured_executor->Spawn([&] { follow_up_ran = true; })); + // Follow-up task can't run right now because there is no thread in the executor + SleepABit(); + ASSERT_FALSE(follow_up_ran); + // Follow-up should run as part of retrieving the next item + ASSERT_OK_AND_EQ(IterationEnd(), iter.Next()); + ASSERT_TRUE(follow_up_ran); +} + +TEST(SerialExecutor, AsyncGeneratorWithCleanup) { + // Sometimes a final task might generate follow-up tasks. Unlike other follow-up + // tasks these must run before we finish the iterator. + bool follow_up_ran = false; + Iterator iter = + SerialExecutor::RunGeneratorInSerialExecutor([&](Executor* executor) { + return [=, &follow_up_ran]() -> Future { + Future end = + DeferNotOk(executor->Submit([] { return IterationEnd(); })); + RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; })); + return end; + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(IterationEnd(), iter.Next()); + ASSERT_TRUE(follow_up_ran); +} + +TEST(SerialExecutor, AbandonIteratorWithCleanup) { + // If we abandon an iterator we still need to drain all remaining tasks + bool follow_up_ran = false; + bool first = true; + { + Iterator iter = + SerialExecutor::RunGeneratorInSerialExecutor([&](Executor* executor) { + return [=, &first, &follow_up_ran]() -> Future { + if (first) { + first = false; + Future end = + DeferNotOk(executor->Submit([] { return TestInt(0); })); + RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; })); + return end; + } + return DeferNotOk(executor->Submit([] { return IterationEnd(); })); + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_OK_AND_EQ(TestInt(0), iter.Next()); + // At this point the iterator still has one remaining cleanup task + ASSERT_FALSE(follow_up_ran); + } + ASSERT_TRUE(follow_up_ran); +} + +TEST(SerialExecutor, FailingIteratorWithCleanup) { + // If an iterator hits an error we should still generally run any remaining tasks as + // they might be cleanup tasks. 
+ bool follow_up_ran = false; + Iterator iter = + SerialExecutor::RunGeneratorInSerialExecutor([&](Executor* executor) { + return [=, &follow_up_ran]() -> Future { + Future end = DeferNotOk(executor->Submit( + []() -> Result { return Status::Invalid("XYZ"); })); + RETURN_NOT_OK(executor->Spawn([&] { follow_up_ran = true; })); + return end; + }; + }); + ASSERT_FALSE(follow_up_ran); + ASSERT_RAISES(Invalid, iter.Next()); + ASSERT_TRUE(follow_up_ran); +} + class TransferTest : public testing::Test { public: internal::Executor* executor() { return mock_executor.get(); } diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 4689543b5b8..ddfb7ef539e 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1700,6 +1700,10 @@ ipc___RecordBatchStreamWriter__Open <- function(stream, schema, use_legacy_forma .Call(`_arrow_ipc___RecordBatchStreamWriter__Open`, stream, schema, use_legacy_format, metadata_version) } +TestSafeCallIntoR <- function() { + .Call(`_arrow_TestSafeCallIntoR`) +} + Array__GetScalar <- function(x, i) { .Call(`_arrow_Array__GetScalar`, x, i) } diff --git a/r/R/config.R b/r/R/config.R index af07ad9a9d2..499e1038b15 100644 --- a/r/R/config.R +++ b/r/R/config.R @@ -33,6 +33,8 @@ set_cpu_count <- function(num_threads) { #' #' @export io_thread_count <- function() { + # I couldn't figure out how to expose a function properly so I'm just hijacking io_thread_count + TestSafeCallIntoR() GetIOThreadPoolCapacity() } diff --git a/r/src/Makevars.in b/r/src/Makevars.in index 7e20f9561ed..99c536e79a8 100644 --- a/r/src/Makevars.in +++ b/r/src/Makevars.in @@ -19,7 +19,7 @@ # PKG_CPPFLAGS=@cflags@ -I../inst/include/ # when it is not: -PKG_CPPFLAGS=@cflags@ +PKG_CPPFLAGS=-g @cflags@ # `-fvisibility=hidden` does not play well with UBSAN: # https://bugs.llvm.org/show_bug.cgi?id=39191 diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index eb77fdc40d1..163e37b8b19 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -6693,6 +6693,20 @@ extern "C" SEXP _arrow_ipc___RecordBatchStreamWriter__Open(SEXP stream_sexp, SEX } #endif +// safe-call-into-r.cpp +#if defined(ARROW_R_WITH_ARROW) +cpp11::strings TestSafeCallIntoR(); +extern "C" SEXP _arrow_TestSafeCallIntoR(){ +BEGIN_CPP11 + return cpp11::as_sexp(TestSafeCallIntoR()); +END_CPP11 +} +#else +extern "C" SEXP _arrow_TestSafeCallIntoR(){ + Rf_error("Cannot call TestSafeCallIntoR(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + // scalar.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr Array__GetScalar(const std::shared_ptr& x, int64_t i); @@ -7523,32 +7537,6 @@ extern "C" SEXP _arrow_Array__infer_type(SEXP x_sexp){ } #endif -#if defined(ARROW_R_WITH_ARROW) -extern "C" SEXP _arrow_Table__Reset(SEXP r6) { -BEGIN_CPP11 -arrow::r::r6_reset_pointer(r6); -END_CPP11 -return R_NilValue; -} -#else -extern "C" SEXP _arrow_Table__Reset(SEXP r6){ - Rf_error("Cannot call Table(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -#if defined(ARROW_R_WITH_ARROW) -extern "C" SEXP _arrow_RecordBatch__Reset(SEXP r6) { -BEGIN_CPP11 -arrow::r::r6_reset_pointer(r6); -END_CPP11 -return R_NilValue; -} -#else -extern "C" SEXP _arrow_RecordBatch__Reset(SEXP r6){ - Rf_error("Cannot call RecordBatch(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); -} -#endif - extern "C" SEXP _arrow_available() { return Rf_ScalarLogical( #if defined(ARROW_R_WITH_ARROW) @@ -7595,11 +7583,11 @@ return Rf_ScalarLogical( ); } static const R_CallMethodDef CallEntries[] = { - { "_arrow_available", (DL_FUNC)& _arrow_available, 0 }, - { "_dataset_available", (DL_FUNC)& _dataset_available, 0 }, - { "_parquet_available", (DL_FUNC)& _parquet_available, 0 }, - { "_s3_available", (DL_FUNC)& _s3_available, 0 }, - { "_json_available", (DL_FUNC)& _json_available, 0 }, +{ "_arrow_available", (DL_FUNC)& _arrow_available, 0 }, +{ "_dataset_available", (DL_FUNC)& _dataset_available, 0 }, +{ "_parquet_available", (DL_FUNC)& _parquet_available, 0 }, +{ "_s3_available", (DL_FUNC)& _s3_available, 0 }, +{ "_json_available", (DL_FUNC)& _json_available, 0 }, { "_arrow_test_SET_STRING_ELT", (DL_FUNC) &_arrow_test_SET_STRING_ELT, 1}, { "_arrow_is_arrow_altrep", (DL_FUNC) &_arrow_is_arrow_altrep, 1}, { "_arrow_Array__Slice1", (DL_FUNC) &_arrow_Array__Slice1, 2}, @@ -8025,6 +8013,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchWriter__Close", (DL_FUNC) &_arrow_ipc___RecordBatchWriter__Close, 1}, { "_arrow_ipc___RecordBatchFileWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchFileWriter__Open, 4}, { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 4}, + { "_arrow_TestSafeCallIntoR", (DL_FUNC) &_arrow_TestSafeCallIntoR, 0}, { "_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, { "_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, { "_arrow_StructScalar__field", (DL_FUNC) &_arrow_StructScalar__field, 2}, @@ -8078,8 +8067,6 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_GetIOThreadPoolCapacity", (DL_FUNC) &_arrow_GetIOThreadPoolCapacity, 0}, { "_arrow_SetIOThreadPoolCapacity", (DL_FUNC) &_arrow_SetIOThreadPoolCapacity, 1}, { "_arrow_Array__infer_type", (DL_FUNC) &_arrow_Array__infer_type, 1}, - { "_arrow_Table__Reset", (DL_FUNC) &_arrow_Table__Reset, 1}, - { "_arrow_RecordBatch__Reset", (DL_FUNC) &_arrow_RecordBatch__Reset, 1}, {NULL, NULL, 0} }; extern "C" void R_init_arrow(DllInfo* dll){ diff --git a/r/src/safe-call-into-r.cpp b/r/src/safe-call-into-r.cpp new file mode 100644 index 00000000000..35174b2f705 --- /dev/null +++ b/r/src/safe-call-into-r.cpp @@ -0,0 +1,37 @@ +#include "safe-call-into-r.h" + +#include + +// [[arrow::export]] +cpp11::strings TestSafeCallIntoR() { + // This simulates the Arrow thread pool. Just imagine it is static and lives forever. + std::thread* thread_ptr; + + // Pretend we are in parquet___arrow___FileReader__ReadTable1, you wouldn't need to + // capture thread_ptr in the real world. + arrow::Result run_res = RunWithCapturedR([&thread_ptr]() { + // We wouldn't Make a future here, we would call ReadTableAsync and + // return that future + arrow::Future fut = arrow::Future::Make(); + thread_ptr = new std::thread([fut]() mutable { + // At this point we are deep in the bowels of parquet reading and need to issue a + // call to the filesystem. So we are in your filesystem adapter here needing to + // call into R + SafeCallIntoR([] { + // The body of this task runs on the R thread, you can do R thread stuff here. + cpp11::function gc = cpp11::package("base")["gc"]; + gc(); + return 0; + // In reality you wouldn't have to worry about the AddCallback because you would + // be implementing something like ReadAtAsync and you could just return a future. 
+    }).AddCallback([fut](const arrow::Result<int>&) mutable { fut.MarkFinished(0); });
+  });
+  return fut;
+  });
+  // Ignore everything below this point
+  thread_ptr->join();
+  delete thread_ptr;
+
+  cpp11::writable::strings results_sexp;
+  return results_sexp;
+}
diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h
new file mode 100644
index 00000000000..7f8a9258640
--- /dev/null
+++ b/r/src/safe-call-into-r.h
@@ -0,0 +1,65 @@
+#ifndef SAFE_CALL_INTO_R_INCLUDED
+#define SAFE_CALL_INTO_R_INCLUDED
+
+#include <functional>
+
+#include "./arrow_types.h"
+
+arrow::internal::Executor*& GetRExecutor() {
+  static arrow::internal::Executor* r_executor = nullptr;
+  return r_executor;
+}
+
+/**
+ * This is called at the top level, from R, when calling a function that might make
+ * callbacks back into R.  For example:
+ *
+ *   std::shared_ptr<arrow::Table> parquet___arrow___FileReader__ReadTable1(
+ *       const std::shared_ptr<parquet::arrow::FileReader>& reader) {
+ *     return RunWithCapturedR([&reader] {
+ *       return reader->ReadTableAsync();
+ *     });
+ *   }
+ *
+ * Note: ReadTableAsync does not exist today.  This would be the first case of
+ * exposing C++ async functionality.  It wouldn't be too hard to create, but you can
+ * maybe see that this is opening a slight can of worms.
+ *
+ * Unfortunately, if we don't call the async version, we have no way in the C++ code
+ * to set up an event loop.
+ */
+template <typename T>
+arrow::Result<T> RunWithCapturedR(std::function<arrow::Future<T>()> task) {
+  if (GetRExecutor() != nullptr) {
+    // Reentrant calls are a no-go; there is only one R thread to capture
+    return arrow::Status::Invalid("RunWithCapturedR called reentrantly");
+  }
+  arrow::Result<T> cmd_result =
+      arrow::internal::SerialExecutor::RunInSerialExecutor<T>(
+          [task](arrow::internal::Executor* executor) {
+            GetRExecutor() = executor;
+            arrow::Future<T> result = task();
+            return result;
+          });
+  GetRExecutor() = nullptr;
+  return cmd_result;
+}
+
+/**
+ * This is called from an Arrow context when you need to call back into R.  For
+ * example, you would call this from your filesystem adapter when you need to make
+ * R calls to get the data.
+ *
+ * Note that this function returns a Future.  That's because we are going to
+ * "pause/yield" this Arrow thread while we wait for the R thread to do its thing.
+ *
+ * Fortunately, the filesystem API already has support for Future-returning API
+ * methods.  You can then simulate the sync version by returning (e.g.)
+ * ReadAtAsync().result();
+ */
+template <typename T>
+arrow::Future<T> SafeCallIntoR(std::function<T()> task) {
+  if (GetRExecutor() == nullptr) {
+    // We are not inside RunWithCapturedR, so there is no R event loop to target
+    return arrow::Future<T>::MakeFinished(
+        arrow::Status::Invalid("SafeCallIntoR called outside RunWithCapturedR"));
+  }
+  return arrow::DeferNotOk(GetRExecutor()->Submit(task));
+}
+
+#endif
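Putting the two entry points together, the intended call pattern looks roughly like this; a sketch only, where SomeAsyncOperation and ReadMetadataFromR are hypothetical stand-ins:

    // On the R thread: wrap the whole async operation in a captured-R scope.
    arrow::Result<int> res = RunWithCapturedR<int>([] {
      return SomeAsyncOperation();  // any Future<int>-returning Arrow call
    });

    // Deep inside Arrow, on a worker thread (e.g. in a filesystem adapter):
    arrow::Future<int> ReadMetadataFromR() {
      return SafeCallIntoR<int>([] {
        // Runs on the R thread; safe to use the R API here
        return 42;
      });
    }

A synchronous wrapper falls out by blocking on the returned future, e.g. ReadMetadataFromR().result(), as the header comment notes.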