Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

namespace arrow {
namespace compute {
namespace aggregate {

namespace {

void AggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
checked_cast<ScalarAggregator*>(ctx->state())->Consume(ctx, batch);
Expand All @@ -38,6 +39,19 @@ void AggregateFinalize(KernelContext* ctx, Datum* out) {
checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, out);
}

} // namespace

void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFunction* func, SimdLevel::type simd_level) {
ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge,
AggregateFinalize);
// Set the simd level
kernel.simd_level = simd_level;
DCHECK_OK(func->AddKernel(kernel));
}

namespace aggregate {

// ----------------------------------------------------------------------
// Count implementation

Expand Down Expand Up @@ -137,15 +151,6 @@ std::unique_ptr<KernelState> MinMaxInit(KernelContext* ctx, const KernelInitArgs
return visitor.Create();
}

void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFunction* func, SimdLevel::type simd_level) {
ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge,
AggregateFinalize);
// Set the simd level
kernel.simd_level = simd_level;
DCHECK_OK(func->AddKernel(kernel));
}

void AddBasicAggKernels(KernelInit init,
const std::vector<std::shared_ptr<DataType>>& types,
std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
Expand Down Expand Up @@ -202,8 +207,8 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {

// Takes any array input, outputs int64 scalar
InputType any_array(ValueDescr::ARRAY);
aggregate::AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())),
aggregate::CountInit, func.get());
AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())),
aggregate::CountInit, func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));

func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), &sum_doc);
Expand Down Expand Up @@ -263,10 +268,6 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
#endif

DCHECK_OK(registry->AddFunction(std::move(func)));

DCHECK_OK(registry->AddFunction(aggregate::AddModeAggKernels()));
DCHECK_OK(registry->AddFunction(aggregate::AddStddevAggKernels()));
DCHECK_OK(registry->AddFunction(aggregate::AddVarianceAggKernels()));
}

} // namespace internal
Expand Down
14 changes: 0 additions & 14 deletions cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,6 @@ namespace arrow {
namespace compute {
namespace aggregate {

struct ScalarAggregator : public KernelState {
virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
virtual void MergeFrom(KernelContext* ctx, KernelState&& src) = 0;
virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
};

void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFunction* func,
SimdLevel::type simd_level = SimdLevel::NONE);

void AddBasicAggKernels(KernelInit init,
const std::vector<std::shared_ptr<DataType>>& types,
std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
Expand All @@ -58,10 +48,6 @@ void AddSumAvx512AggKernels(ScalarAggregateFunction* func);
void AddMeanAvx512AggKernels(ScalarAggregateFunction* func);
void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func);

std::shared_ptr<ScalarAggregateFunction> AddModeAggKernels();
std::shared_ptr<ScalarAggregateFunction> AddStddevAggKernels();
std::shared_ptr<ScalarAggregateFunction> AddVarianceAggKernels();

// ----------------------------------------------------------------------
// Sum implementation

Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,15 @@ struct FindAccumulatorType<I, enable_if_floating_point<I>> {
using Type = DoubleType;
};

struct ScalarAggregator : public KernelState {
virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
virtual void MergeFrom(KernelContext* ctx, KernelState&& src) = 0;
virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
};

void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFunction* func,
SimdLevel::type simd_level = SimdLevel::NONE);

} // namespace compute
} // namespace arrow
18 changes: 12 additions & 6 deletions cpp/src/arrow/compute/kernels/aggregate_mode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
#include <cmath>
#include <unordered_map>

#include "arrow/compute/kernels/aggregate_basic_internal.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/common.h"

namespace arrow {
namespace compute {
namespace aggregate {
namespace internal {

namespace {

Expand Down Expand Up @@ -277,16 +279,20 @@ const FunctionDoc mode_doc{
"null is returned."),
{"array"}};

} // namespace

std::shared_ptr<ScalarAggregateFunction> AddModeAggKernels() {
auto func =
std::make_shared<ScalarAggregateFunction>("mode", Arity::Unary(), &mode_doc);
AddModeKernels(ModeInit, {boolean()}, func.get());
AddModeKernels(ModeInit, internal::NumericTypes(), func.get());
AddModeKernels(ModeInit, NumericTypes(), func.get());
return func;
}

} // namespace aggregate
} // namespace

void RegisterScalarAggregateMode(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(AddModeAggKernels()));
}

} // namespace internal
} // namespace compute
} // namespace arrow
23 changes: 16 additions & 7 deletions cpp/src/arrow/compute/kernels/aggregate_var_std.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/kernels/aggregate_basic_internal.h"
#include <cmath>

#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/util/int128_internal.h"

namespace arrow {
namespace compute {
namespace aggregate {
namespace internal {

namespace {

Expand Down Expand Up @@ -252,24 +256,29 @@ const FunctionDoc variance_doc{
{"array"},
"VarianceOptions"};

} // namespace

std::shared_ptr<ScalarAggregateFunction> AddStddevAggKernels() {
static auto default_std_options = VarianceOptions::Defaults();
auto func = std::make_shared<ScalarAggregateFunction>(
"stddev", Arity::Unary(), &stddev_doc, &default_std_options);
AddVarStdKernels(StddevInit, internal::NumericTypes(), func.get());
AddVarStdKernels(StddevInit, NumericTypes(), func.get());
return func;
}

std::shared_ptr<ScalarAggregateFunction> AddVarianceAggKernels() {
static auto default_var_options = VarianceOptions::Defaults();
auto func = std::make_shared<ScalarAggregateFunction>(
"variance", Arity::Unary(), &variance_doc, &default_var_options);
AddVarStdKernels(VarianceInit, internal::NumericTypes(), func.get());
AddVarStdKernels(VarianceInit, NumericTypes(), func.get());
return func;
}

} // namespace aggregate
} // namespace

void RegisterScalarAggregateVariance(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(AddVarianceAggKernels()));
DCHECK_OK(registry->AddFunction(AddStddevAggKernels()));
}

} // namespace internal
} // namespace compute
} // namespace arrow
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {

// Aggregate functions
RegisterScalarAggregateBasic(registry.get());
RegisterScalarAggregateMode(registry.get());
RegisterScalarAggregateVariance(registry.get());

// Vector functions
RegisterVectorHash(registry.get());
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/registry_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ void RegisterVectorSort(FunctionRegistry* registry);

// Aggregate functions
void RegisterScalarAggregateBasic(FunctionRegistry* registry);
void RegisterScalarAggregateMode(FunctionRegistry* registry);
void RegisterScalarAggregateVariance(FunctionRegistry* registry);

} // namespace internal
} // namespace compute
Expand Down