diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 0d03f7340d1..11c1e2b1730 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -24,7 +24,8 @@ namespace arrow { namespace compute { -namespace aggregate { + +namespace { void AggregateConsume(KernelContext* ctx, const ExecBatch& batch) { checked_cast(ctx->state())->Consume(ctx, batch); @@ -38,6 +39,19 @@ void AggregateFinalize(KernelContext* ctx, Datum* out) { checked_cast(ctx->state())->Finalize(ctx, out); } +} // namespace + +void AddAggKernel(std::shared_ptr sig, KernelInit init, + ScalarAggregateFunction* func, SimdLevel::type simd_level) { + ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge, + AggregateFinalize); + // Set the simd level + kernel.simd_level = simd_level; + DCHECK_OK(func->AddKernel(kernel)); +} + +namespace aggregate { + // ---------------------------------------------------------------------- // Count implementation @@ -137,15 +151,6 @@ std::unique_ptr MinMaxInit(KernelContext* ctx, const KernelInitArgs return visitor.Create(); } -void AddAggKernel(std::shared_ptr sig, KernelInit init, - ScalarAggregateFunction* func, SimdLevel::type simd_level) { - ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge, - AggregateFinalize); - // Set the simd level - kernel.simd_level = simd_level; - DCHECK_OK(func->AddKernel(kernel)); -} - void AddBasicAggKernels(KernelInit init, const std::vector>& types, std::shared_ptr out_ty, ScalarAggregateFunction* func, @@ -202,8 +207,8 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { // Takes any array input, outputs int64 scalar InputType any_array(ValueDescr::ARRAY); - aggregate::AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())), - aggregate::CountInit, func.get()); + AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())), + aggregate::CountInit, func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared("sum", Arity::Unary(), &sum_doc); @@ -263,10 +268,6 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { #endif DCHECK_OK(registry->AddFunction(std::move(func))); - - DCHECK_OK(registry->AddFunction(aggregate::AddModeAggKernels())); - DCHECK_OK(registry->AddFunction(aggregate::AddStddevAggKernels())); - DCHECK_OK(registry->AddFunction(aggregate::AddVarianceAggKernels())); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index 2b0631ee2f2..733e6d1d0a6 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -29,16 +29,6 @@ namespace arrow { namespace compute { namespace aggregate { -struct ScalarAggregator : public KernelState { - virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0; - virtual void MergeFrom(KernelContext* ctx, KernelState&& src) = 0; - virtual void Finalize(KernelContext* ctx, Datum* out) = 0; -}; - -void AddAggKernel(std::shared_ptr sig, KernelInit init, - ScalarAggregateFunction* func, - SimdLevel::type simd_level = SimdLevel::NONE); - void AddBasicAggKernels(KernelInit init, const std::vector>& types, std::shared_ptr out_ty, ScalarAggregateFunction* func, @@ -58,10 +48,6 @@ void AddSumAvx512AggKernels(ScalarAggregateFunction* func); void AddMeanAvx512AggKernels(ScalarAggregateFunction* func); void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func); -std::shared_ptr AddModeAggKernels(); -std::shared_ptr AddStddevAggKernels(); -std::shared_ptr AddVarianceAggKernels(); - // ---------------------------------------------------------------------- // Sum implementation diff --git a/cpp/src/arrow/compute/kernels/aggregate_internal.h b/cpp/src/arrow/compute/kernels/aggregate_internal.h index 5f2f50c0b06..cb67794d942 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_internal.h @@ -47,5 +47,15 @@ struct FindAccumulatorType> { using Type = DoubleType; }; +struct ScalarAggregator : public KernelState { + virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0; + virtual void MergeFrom(KernelContext* ctx, KernelState&& src) = 0; + virtual void Finalize(KernelContext* ctx, Datum* out) = 0; +}; + +void AddAggKernel(std::shared_ptr sig, KernelInit init, + ScalarAggregateFunction* func, + SimdLevel::type simd_level = SimdLevel::NONE); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_mode.cc b/cpp/src/arrow/compute/kernels/aggregate_mode.cc index 352c592bb11..6544df549e6 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_mode.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_mode.cc @@ -18,11 +18,13 @@ #include #include -#include "arrow/compute/kernels/aggregate_basic_internal.h" +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/common.h" namespace arrow { namespace compute { -namespace aggregate { +namespace internal { namespace { @@ -277,16 +279,20 @@ const FunctionDoc mode_doc{ "null is returned."), {"array"}}; -} // namespace - std::shared_ptr AddModeAggKernels() { auto func = std::make_shared("mode", Arity::Unary(), &mode_doc); AddModeKernels(ModeInit, {boolean()}, func.get()); - AddModeKernels(ModeInit, internal::NumericTypes(), func.get()); + AddModeKernels(ModeInit, NumericTypes(), func.get()); return func; } -} // namespace aggregate +} // namespace + +void RegisterScalarAggregateMode(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunction(AddModeAggKernels())); +} + +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index 4bffd5a754b..4dac0a37734 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/kernels/aggregate_basic_internal.h" +#include + +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/common.h" #include "arrow/util/int128_internal.h" namespace arrow { namespace compute { -namespace aggregate { +namespace internal { namespace { @@ -252,13 +256,11 @@ const FunctionDoc variance_doc{ {"array"}, "VarianceOptions"}; -} // namespace - std::shared_ptr AddStddevAggKernels() { static auto default_std_options = VarianceOptions::Defaults(); auto func = std::make_shared( "stddev", Arity::Unary(), &stddev_doc, &default_std_options); - AddVarStdKernels(StddevInit, internal::NumericTypes(), func.get()); + AddVarStdKernels(StddevInit, NumericTypes(), func.get()); return func; } @@ -266,10 +268,17 @@ std::shared_ptr AddVarianceAggKernels() { static auto default_var_options = VarianceOptions::Defaults(); auto func = std::make_shared( "variance", Arity::Unary(), &variance_doc, &default_var_options); - AddVarStdKernels(VarianceInit, internal::NumericTypes(), func.get()); + AddVarStdKernels(VarianceInit, NumericTypes(), func.get()); return func; } -} // namespace aggregate +} // namespace + +void RegisterScalarAggregateVariance(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunction(AddVarianceAggKernels())); + DCHECK_OK(registry->AddFunction(AddStddevAggKernels())); +} + +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 376e8206cc1..7ef1e26d59b 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -128,6 +128,8 @@ static std::unique_ptr CreateBuiltInRegistry() { // Aggregate functions RegisterScalarAggregateBasic(registry.get()); + RegisterScalarAggregateMode(registry.get()); + RegisterScalarAggregateVariance(registry.get()); // Vector functions RegisterVectorHash(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index d84f85cd153..78e134eb41f 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -43,6 +43,8 @@ void RegisterVectorSort(FunctionRegistry* registry); // Aggregate functions void RegisterScalarAggregateBasic(FunctionRegistry* registry); +void RegisterScalarAggregateMode(FunctionRegistry* registry); +void RegisterScalarAggregateVariance(FunctionRegistry* registry); } // namespace internal } // namespace compute