Skip to content
Closed
4 changes: 2 additions & 2 deletions cpp/examples/arrow/compute_register_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ class ExampleNode : public cp::ExecNode {
void ResumeProducing(ExecNode* output) override {}
void PauseProducing(ExecNode* output) override {}

void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); }
void StopProducing() override { inputs_[0]->StopProducing(); }
void StopProducing(ExecNode* output) override {}
void StopProducing() override {}

void InputReceived(ExecNode* input, cp::ExecBatch batch) override {}
void ErrorReceived(ExecNode* input, arrow::Status error) override {}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/exec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,8 @@ ExecContext::ExecContext(MemoryPool* pool, ::arrow::internal::Executor* executor
FunctionRegistry* func_registry)
: pool_(pool), executor_(executor) {
this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry;
DCHECK_NE(nullptr, executor_);
this->use_threads_ = executor_->GetCapacity() > 1;
}

CpuInfo* ExecContext::cpu_info() const { return CpuInfo::GetInstance(); }
Expand Down
8 changes: 5 additions & 3 deletions cpp/src/arrow/compute/exec.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "arrow/result.h"
#include "arrow/type_fwd.h"
#include "arrow/util/macros.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/type_fwd.h"
#include "arrow/util/visibility.h"

Expand All @@ -60,9 +61,10 @@ static constexpr int64_t kDefaultExecChunksize = UINT16_MAX;
class ARROW_EXPORT ExecContext {
public:
// If no function registry passed, the default is used.
explicit ExecContext(MemoryPool* pool = default_memory_pool(),
::arrow::internal::Executor* executor = NULLPTR,
FunctionRegistry* func_registry = NULLPTR);
explicit ExecContext(
MemoryPool* pool = default_memory_pool(),
::arrow::internal::Executor* executor = ::arrow::internal::GetCpuThreadPool(),
FunctionRegistry* func_registry = NULLPTR);

/// \brief The MemoryPool used for allocations, default is
/// default_memory_pool().
Expand Down
4 changes: 1 addition & 3 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ class ScalarAggregateNode : public ExecNode {
if (input_counter_.Cancel()) {
finished_.MarkFinished();
}
inputs_[0]->StopProducing(this);
}

Future<> finished() override { return finished_; }
Expand Down Expand Up @@ -531,7 +530,7 @@ class GroupByNode : public ExecNode {

auto executor = ctx_->executor();
for (int i = 0; i < num_output_batches; ++i) {
if (executor) {
if (ctx_->use_threads()) {
// bail if StopProducing was called
if (finished_.is_finished()) break;

Expand Down Expand Up @@ -610,7 +609,6 @@ class GroupByNode : public ExecNode {
if (output_counter_.Cancel()) {
finished_.MarkFinished();
}
inputs_[0]->StopProducing(this);
}

void StopProducing() override { StopProducing(outputs_[0]); }
Expand Down
36 changes: 16 additions & 20 deletions cpp/src/arrow/compute/exec/exec_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,16 +360,11 @@ bool ExecNode::ErrorIfNotOk(Status status) {
}

MapNode::MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema, bool async_mode)
std::shared_ptr<Schema> output_schema, bool use_threads)
: ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"},
std::move(output_schema),
/*num_outputs=*/1) {
if (async_mode) {
executor_ = plan_->exec_context()->executor();
} else {
executor_ = nullptr;
}
}
/*num_outputs=*/1),
use_threads_(use_threads) {}

void MapNode::ErrorReceived(ExecNode* input, Status error) {
DCHECK_EQ(input, inputs_[0]);
Expand Down Expand Up @@ -406,13 +401,14 @@ void MapNode::StopProducing(ExecNode* output) {

void MapNode::StopProducing() {
EVENT(span_, "StopProducing");
if (executor_) {
if (use_threads_) {
// If we are using tasks we may have a bunch of queued tasks that we should
// cancel
this->stop_source_.RequestStop();
}
if (input_counter_.Cancel()) {
this->Finish();
}
inputs_[0]->StopProducing(this);
}

Future<> MapNode::finished() { return finished_; }
Expand All @@ -436,15 +432,16 @@ void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
return Status::OK();
};

if (executor_) {
if (use_threads_) {
status = task_group_.AddTask([this, task]() -> Result<Future<>> {
return this->executor_->Submit(this->stop_source_.token(), [this, task]() {
auto status = task();
if (this->input_counter_.Increment()) {
this->Finish(status);
}
return status;
});
return this->plan()->exec_context()->executor()->Submit(
this->stop_source_.token(), [this, task]() {
auto status = task();
if (this->input_counter_.Increment()) {
this->Finish(status);
}
return status;
});
});
} else {
status = task();
Expand All @@ -458,13 +455,12 @@ void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
if (input_counter_.Cancel()) {
this->Finish(status);
}
inputs_[0]->StopProducing(this);
return;
}
}

void MapNode::Finish(Status finish_st /*= Status::OK()*/) {
if (executor_) {
if (use_threads_) {
task_group_.End().AddCallback([this, finish_st](const Status& st) {
Status final_status = finish_st & st;
this->finished_.MarkFinished(final_status);
Expand Down
7 changes: 5 additions & 2 deletions cpp/src/arrow/compute/exec/exec_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ class ARROW_EXPORT ExecNode {
class MapNode : public ExecNode {
public:
MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema, bool async_mode);
std::shared_ptr<Schema> output_schema, bool use_threads);

void ErrorReceived(ExecNode* input, Status error) override;

Expand Down Expand Up @@ -303,7 +303,10 @@ class MapNode : public ExecNode {
// The task group for the corresponding batches
util::AsyncTaskGroup task_group_;

::arrow::internal::Executor* executor_;
// If true then tasks will be spawned for each item
//
// If false the item will be processed immediately and synchronously
bool use_threads_;

// Variable used to cancel remaining tasks in the executor
StopSource stop_source_;
Expand Down
41 changes: 26 additions & 15 deletions cpp/src/arrow/compute/exec/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,24 @@ BatchesWithSchema GenerateBatchesFromString(
return out_batches;
}

std::unique_ptr<ExecContext> SimpleExecContext(
bool parallel, std::shared_ptr<::arrow::internal::ThreadPool>* owned_thread_pool) {
if (parallel) {
return arrow::internal::make_unique<ExecContext>();
} else {
EXPECT_OK_AND_ASSIGN(*owned_thread_pool, ::arrow::internal::ThreadPool::Make(1));
return arrow::internal::make_unique<ExecContext>(default_memory_pool(),
owned_thread_pool->get());
}
}

void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches,
const BatchesWithSchema& r_batches,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& right_keys,
const BatchesWithSchema& exp_batches, bool parallel = false) {
auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool;
auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));

Expand Down Expand Up @@ -904,8 +915,8 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel,
const std::vector<std::shared_ptr<Array>>& l,
const std::vector<std::shared_ptr<Array>>& r, int num_batches_l,
int num_batches_r, std::shared_ptr<Table>* output) {
auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool;
auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));

Expand Down Expand Up @@ -1013,8 +1024,8 @@ TEST(HashJoin, Random) {
#endif
for (int test_id = 0; test_id < num_tests; ++test_id) {
bool parallel = (rng.from_range(0, 1) == 1);
auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool;
auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool);

// Constraints
RandomDataTypeConstraints type_constraints;
Expand Down Expand Up @@ -1297,8 +1308,8 @@ void TestHashJoinDictionaryHelper(
}
}

auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool;
auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool);
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
ASSERT_OK_AND_ASSIGN(
ExecNode * l_source,
Expand Down Expand Up @@ -1699,8 +1710,8 @@ TEST(HashJoin, DictNegative) {
ExecBatch::Make({i == 2 ? datumSecondB : datumSecondA,
i == 3 ? datumSecondB : datumSecondA}));

auto exec_ctx =
arrow::internal::make_unique<ExecContext>(default_memory_pool(), nullptr);
std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool;
auto exec_ctx = SimpleExecContext(/*parallel=*/false, &owned_thread_pool);
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
ASSERT_OK_AND_ASSIGN(
ExecNode * l_source,
Expand Down Expand Up @@ -1728,6 +1739,7 @@ TEST(HashJoin, DictNegative) {
EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Unifying differing dictionaries"),
StartAndCollect(plan.get(), sink_gen));
ASSERT_FINISHES_OK(plan->finished());
}
}

Expand Down Expand Up @@ -1786,8 +1798,8 @@ TEST(HashJoin, ResidualFilter) {
input_right.schema =
schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())});

auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool;
auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
Expand Down Expand Up @@ -1863,9 +1875,8 @@ TEST(HashJoin, TrivialResidualFilter) {
])")};
input_right.schema = schema({field("r1", int32()), field("r_str", utf8())});

auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(),
parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
std::shared_ptr<::arrow::internal::ThreadPool> owned_thread_pool;
auto exec_ctx = SimpleExecContext(parallel, &owned_thread_pool);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
Expand Down
44 changes: 25 additions & 19 deletions cpp/src/arrow/compute/exec/plan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,11 @@ TEST(ExecPlanExecution, SourceSinkError) {

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
Finishes(Raises(StatusCode::Invalid, HasSubstr("Artificial"))));
// Note: the plan may or may not be finished at this point. When an error
// hits the sink node it starts to mark itself finished but before that it emits
// the error to the producer which will cause the above wait to finish (possibly
// before the plan has marked itself finished). So we wait for the plan to finish.
ASSERT_FINISHES_OK(plan->finished());
}

TEST(ExecPlanExecution, SourceConsumingSink) {
Expand Down Expand Up @@ -527,7 +532,7 @@ TEST(ExecPlanExecution, SourceConsumingSink) {
}
}

TEST(ExecPlanExecution, SourceTableConsumingSink) {
TEST(ExecPlanExecution, SourceTableSink) {
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");

Expand All @@ -549,11 +554,12 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) {
ASSERT_OK(plan->StartProducing());
// Source should finish fairly quickly
ASSERT_FINISHES_OK(source->finished());
SleepABit();
ASSERT_OK_AND_ASSIGN(auto actual,
ASSERT_OK_AND_ASSIGN(auto expected,
TableFromExecBatches(basic_data.schema, basic_data.batches));
ASSERT_EQ(5, out->num_rows());
AssertTablesEqual(*actual, *out);
ASSERT_OK_AND_ASSIGN(auto expected_sorted, SortTableOnAllFields(expected));
ASSERT_OK_AND_ASSIGN(auto out_sorted, SortTableOnAllFields(out));
AssertTablesEqual(*expected_sorted, *out_sorted);
ASSERT_FINISHES_OK(plan->finished());
}
}
Expand Down Expand Up @@ -856,11 +862,11 @@ TEST(ExecPlanExecution, SourceGroupedSum) {
})
.AddToPlan(plan.get()));

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0),
Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON(
{int64(), utf8()},
parallel ? R"([[800, "alfa"], [1000, "beta"], [400, "gama"]])"
: R"([[8, "alfa"], [10, "beta"], [4, "gama"]])")}))));
parallel ? R"([[400, "gama"], [800, "alfa"], [1000, "beta"]])"
: R"([[4, "gama"], [8, "alfa"], [10, "beta"]])")}))));
}
}

Expand Down Expand Up @@ -900,9 +906,9 @@ TEST(ExecPlanExecution, NestedSourceProjectGroupedSum) {

auto input = MakeNestedBatches();
auto expected = ExecBatchFromJSON({int64(), boolean()}, R"([
[null, true],
[5, null],
[17, false],
[5, null]
[null, true]
])");

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
Expand All @@ -926,7 +932,7 @@ TEST(ExecPlanExecution, NestedSourceProjectGroupedSum) {
})
.AddToPlan(plan.get()));

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0),
Finishes(ResultWith(UnorderedElementsAreArray({expected}))));
}
}
Expand Down Expand Up @@ -962,10 +968,10 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumFilter) {
})
.AddToPlan(plan.get()));

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0),
Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON(
{int64(), utf8()}, parallel ? R"([[3600, "alfa"], [2000, "beta"]])"
: R"([[36, "alfa"], [20, "beta"]])")}))));
{int64(), utf8()}, parallel ? R"([[2000, "beta"], [3600, "alfa"]])"
: R"([[20, "beta"], [36, "alfa"]])")}))));
}
}

Expand Down Expand Up @@ -1123,7 +1129,7 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) {
.AddToPlan(plan.get()));
}

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 1),
Finishes(ResultWith(UnorderedElementsAreArray({
ExecBatchFromJSON({int64(), utf8()},
R"([[500, "alfa"], [200, "beta"], [200, "gama"]])"),
Expand Down Expand Up @@ -1211,7 +1217,7 @@ TEST(ExecPlanExecution, ScalarSourceGroupedSum) {
})
.AddToPlan(plan.get()));

ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
ASSERT_THAT(StartAndCollectSortedByField(plan.get(), sink_gen, 0),
Finishes(ResultWith(UnorderedElementsAreArray({
ExecBatchFromJSON({int64(), boolean()}, R"([[6, true], [18, false]])"),
}))));
Expand All @@ -1223,8 +1229,8 @@ TEST(ExecPlanExecution, SelfInnerHashJoinSink) {

auto input = MakeGroupableBatches();

auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
auto exec_ctx = arrow::internal::make_unique<ExecContext>(default_memory_pool());
exec_ctx->set_use_threads(parallel);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
Expand Down Expand Up @@ -1280,8 +1286,8 @@ TEST(ExecPlanExecution, SelfOuterHashJoinSink) {

auto input = MakeGroupableBatches();

auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
auto exec_ctx = arrow::internal::make_unique<ExecContext>(default_memory_pool());
exec_ctx->set_use_threads(parallel);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
Expand Down
Loading