diff --git a/velox/exec/Aggregate.cpp b/velox/exec/Aggregate.cpp index 1bab06a1a83..350d1b1e9a4 100644 --- a/velox/exec/Aggregate.cpp +++ b/velox/exec/Aggregate.cpp @@ -79,23 +79,29 @@ AggregateRegistrationResult registerAggregateFunction( registered.mainFunction = inserted; } - // Register the aggregate as a window function also. - registerAggregateWindowFunction(sanitizedName); + // If the aggregate is not a companion function, also register it as a window + // function. + if (!metadata.companionFunction) { + registerAggregateWindowFunction(sanitizedName); + } // Register companion function if needed. if (registerCompanionFunctions) { + auto companionMetadata = metadata; + companionMetadata.companionFunction = true; + registered.partialFunction = CompanionFunctionsRegistrar::registerPartialFunction( - name, signatures, overwrite); + name, signatures, companionMetadata, overwrite); registered.mergeFunction = CompanionFunctionsRegistrar::registerMergeFunction( - name, signatures, overwrite); + name, signatures, companionMetadata, overwrite); registered.extractFunction = CompanionFunctionsRegistrar::registerExtractFunction( name, signatures, overwrite); registered.mergeExtractFunction = CompanionFunctionsRegistrar::registerMergeExtractFunction( - name, signatures, overwrite); + name, signatures, companionMetadata, overwrite); } return registered; } @@ -141,6 +147,15 @@ std::vector registerAggregateFunction( return registrationResults; } +const AggregateFunctionMetadata& getAggregateFunctionMetadata( + const std::string& name) { + const auto sanitizedName = sanitizeName(name); + if (auto func = getAggregateFunctionEntry(sanitizedName)) { + return func->metadata; + } + VELOX_USER_FAIL("Aggregate function not found: {}", name); +} + std::unordered_map< std::string, std::vector>> diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index d6bc12aefcd..36e49623cfb 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -476,6 +476,9 @@ struct AggregateFunctionMetadata { /// True if results of the aggregation depend on the order of inputs. For /// example, array_agg is order sensitive while count is not. bool orderSensitive{true}; + + /// Indicates if this is a companion function. + bool companionFunction{false}; }; /// Register an aggregate function with the specified name and signatures. If /// registerCompanionFunctions is true, also register companion aggregate and @@ -514,6 +517,9 @@ std::vector registerAggregateFunction( bool registerCompanionFunctions, bool overwrite); +const AggregateFunctionMetadata& getAggregateFunctionMetadata( + const std::string& name); + /// Returns signatures of the aggregate function with the specified name. /// Returns empty std::optional if function with that name is not found. std::optional>> diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 1c443511f85..ea3c64c5c9e 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -249,6 +249,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply( bool CompanionFunctionsRegistrar::registerPartialFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { auto partialSignatures = CompanionSignatures::partialFunctionSignatures(signatures); @@ -280,6 +281,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction( name, CompanionSignatures::partialFunctionName(name)); }, + metadata, /*registerCompanionFunctions*/ false, overwrite) .mainFunction; @@ -288,6 +290,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction( bool CompanionFunctionsRegistrar::registerMergeFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { auto mergeSignatures = CompanionSignatures::mergeFunctionSignatures(signatures); @@ -320,16 +323,18 @@ bool CompanionFunctionsRegistrar::registerMergeFunction( name, CompanionSignatures::mergeFunctionName(name)); }, + metadata, /*registerCompanionFunctions*/ false, overwrite) .mainFunction; } -bool registerAggregateFunction( +bool registerMergeExtractFunctionInternal( const std::string& name, const std::string& mergeExtractFunctionName, const std::vector>& mergeExtractSignatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { return exec::registerAggregateFunction( mergeExtractFunctionName, @@ -365,6 +370,7 @@ bool registerAggregateFunction( name, mergeExtractFunctionName); }, + metadata, /*registerCompanionFunctions*/ false, overwrite) .mainFunction; @@ -373,6 +379,7 @@ bool registerAggregateFunction( bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { auto groupedSignatures = CompanionSignatures::groupSignaturesByReturnType(signatures); @@ -387,10 +394,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( auto mergeExtractFunctionName = CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type); - registered |= registerAggregateFunction( + registered |= registerMergeExtractFunctionInternal( name, mergeExtractFunctionName, std::move(mergeExtractSignatures), + metadata, overwrite); } return registered; @@ -399,10 +407,12 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( bool CompanionFunctionsRegistrar::registerMergeExtractFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( signatures)) { - return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite); + return registerMergeExtractFunctionWithSuffix( + name, signatures, metadata, overwrite); } auto mergeExtractSignatures = @@ -413,10 +423,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction( auto mergeExtractFunctionName = CompanionSignatures::mergeExtractFunctionName(name); - return registerAggregateFunction( + return registerMergeExtractFunctionInternal( name, mergeExtractFunctionName, std::move(mergeExtractSignatures), + metadata, overwrite); } @@ -475,6 +486,7 @@ bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix( std::move(factory), exec::VectorFunctionMetadataBuilder() .defaultNullBehavior(false) + .companionFunction(true) .build(), overwrite); } @@ -502,7 +514,10 @@ bool CompanionFunctionsRegistrar::registerExtractFunction( CompanionSignatures::extractFunctionName(originalName), std::move(extractSignatures), std::move(factory), - exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(), + exec::VectorFunctionMetadataBuilder() + .defaultNullBehavior(false) + .companionFunction(true) + .build(), overwrite); } diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 8ebef51fcc4..8a6af66ff51 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -178,6 +178,7 @@ class CompanionFunctionsRegistrar { static bool registerPartialFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite = false); // When there is already a function of the same name as the merge companion @@ -186,6 +187,7 @@ class CompanionFunctionsRegistrar { static bool registerMergeFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite = false); // If there are multiple signatures of the original aggregation function @@ -213,6 +215,7 @@ class CompanionFunctionsRegistrar { static bool registerMergeExtractFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite = false); private: @@ -227,6 +230,7 @@ class CompanionFunctionsRegistrar { static bool registerMergeExtractFunctionWithSuffix( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite); }; diff --git a/velox/expression/FunctionMetadata.h b/velox/expression/FunctionMetadata.h index bc354110b95..d8d743f9b2f 100644 --- a/velox/expression/FunctionMetadata.h +++ b/velox/expression/FunctionMetadata.h @@ -40,6 +40,9 @@ struct VectorFunctionMetadata { /// In this case, 'rows' in VectorFunction::apply will point only to positions /// for which all arguments are not null. bool defaultNullBehavior{true}; + + /// Indicates if this is a companion function. + bool companionFunction{false}; }; class VectorFunctionMetadataBuilder { @@ -59,6 +62,11 @@ class VectorFunctionMetadataBuilder { return *this; } + VectorFunctionMetadataBuilder& companionFunction(bool companionFunction) { + metadata_.companionFunction = companionFunction; + return *this; + } + const VectorFunctionMetadata& build() const { return metadata_; } diff --git a/velox/functions/lib/aggregates/BitwiseAggregateBase.h b/velox/functions/lib/aggregates/BitwiseAggregateBase.h index cb9405d2c3d..972ae40c0ad 100644 --- a/velox/functions/lib/aggregates/BitwiseAggregateBase.h +++ b/velox/functions/lib/aggregates/BitwiseAggregateBase.h @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerBitwise( inputType->kindName()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/AverageAggregate.cpp b/velox/functions/prestosql/aggregates/AverageAggregate.cpp index 78aa7eabb58..bb98ce74695 100644 --- a/velox/functions/prestosql/aggregates/AverageAggregate.cpp +++ b/velox/functions/prestosql/aggregates/AverageAggregate.cpp @@ -155,7 +155,7 @@ void registerAverageAggregate( } } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/BoolAggregates.cpp b/velox/functions/prestosql/aggregates/BoolAggregates.cpp index 5095cb373cc..8fa78746550 100644 --- a/velox/functions/prestosql/aggregates/BoolAggregates.cpp +++ b/velox/functions/prestosql/aggregates/BoolAggregates.cpp @@ -209,7 +209,7 @@ exec::AggregateRegistrationResult registerBool( inputType->kindName()); return std::make_unique(); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp b/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp index 1b209bfa18d..de594177e49 100644 --- a/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp @@ -262,7 +262,7 @@ void registerChecksumAggregate( return std::make_unique(VARBINARY()); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/CountAggregate.cpp b/velox/functions/prestosql/aggregates/CountAggregate.cpp index 335d7cb06f4..9a1559e23e4 100644 --- a/velox/functions/prestosql/aggregates/CountAggregate.cpp +++ b/velox/functions/prestosql/aggregates/CountAggregate.cpp @@ -182,7 +182,7 @@ void registerCountAggregate( argTypes.size(), 1, "{} takes at most one argument", name); return std::make_unique(); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/CountIfAggregate.cpp b/velox/functions/prestosql/aggregates/CountIfAggregate.cpp index 5da65ec7ce4..2fcc238cc5c 100644 --- a/velox/functions/prestosql/aggregates/CountIfAggregate.cpp +++ b/velox/functions/prestosql/aggregates/CountIfAggregate.cpp @@ -206,7 +206,7 @@ void registerCountIfAggregate( return std::make_unique(); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp index 9bbe9ec8c73..7226eda2583 100644 --- a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp +++ b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp @@ -136,7 +136,7 @@ void registerGeometricMeanAggregate( inputType->toString()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/HistogramAggregate.cpp b/velox/functions/prestosql/aggregates/HistogramAggregate.cpp index c2481925b42..b920530c426 100644 --- a/velox/functions/prestosql/aggregates/HistogramAggregate.cpp +++ b/velox/functions/prestosql/aggregates/HistogramAggregate.cpp @@ -632,7 +632,7 @@ void registerHistogramAggregate( inputType->toString()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp index 871a62733d3..fd7684829ae 100644 --- a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -514,7 +514,7 @@ exec::AggregateRegistrationResult registerMinMax( } } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/ReduceAgg.cpp b/velox/functions/prestosql/aggregates/ReduceAgg.cpp index 7c7553ffd3e..370d013e425 100644 --- a/velox/functions/prestosql/aggregates/ReduceAgg.cpp +++ b/velox/functions/prestosql/aggregates/ReduceAgg.cpp @@ -817,7 +817,7 @@ void registerReduceAgg( const core::QueryConfig& config) -> std::unique_ptr { return std::make_unique(resultType); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/SumAggregate.cpp b/velox/functions/prestosql/aggregates/SumAggregate.cpp index 8fbbff7c603..919fcf8a303 100644 --- a/velox/functions/prestosql/aggregates/SumAggregate.cpp +++ b/velox/functions/prestosql/aggregates/SumAggregate.cpp @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerSum( inputType->kindName()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*companionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp b/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp index 40777e874ef..6862e3b8862 100644 --- a/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp @@ -75,26 +75,16 @@ TEST_F(AggregationFunctionRegTest, orderSensitive) { "histogram", "reduce_agg"}; aggregate::prestosql::registerAllAggregateFunctions(); - exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) { - for (const auto& entry : aggrFuncMap) { - if (!entry.second.metadata.orderSensitive) { - EXPECT_EQ(1, nonOrderSensitiveFunctions.erase(entry.first)); - } - } - }); - EXPECT_EQ(0, nonOrderSensitiveFunctions.size()); + for (const auto& entry : nonOrderSensitiveFunctions) { + ASSERT_FALSE(exec::getAggregateFunctionMetadata(entry).orderSensitive); + } // Test some but not all order sensitive functions std::set orderSensitiveFunctions = { "array_agg", "arbitrary", "any_value", "map_agg", "map_union", "set_agg"}; - exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) { - for (const auto& entry : aggrFuncMap) { - if (entry.second.metadata.orderSensitive) { - orderSensitiveFunctions.erase(entry.first); - } - } - }); - EXPECT_EQ(0, orderSensitiveFunctions.size()); + for (const auto& entry : orderSensitiveFunctions) { + ASSERT_TRUE(exec::getAggregateFunctionMetadata(entry).orderSensitive); + } } TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) { @@ -121,4 +111,22 @@ TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) { clearAndCheckRegistry(); } +TEST_F(AggregationFunctionRegTest, companionFunction) { + // Remove all functions and check for no entries. + clearAndCheckRegistry(); + + aggregate::prestosql::registerAllAggregateFunctions(); + const auto aggregates = {"approx_distinct", "count", "sum"}; + const auto companionFunctions = { + "approx_distinct_merge", "approx_distinct_partial"}; + + for (const auto& function : aggregates) { + ASSERT_FALSE( + exec::getAggregateFunctionMetadata(function).companionFunction); + } + for (const auto& function : companionFunctions) { + ASSERT_TRUE(exec::getAggregateFunctionMetadata(function).companionFunction); + } +} + } // namespace facebook::velox::aggregate::test diff --git a/velox/functions/tests/FunctionRegistryTest.cpp b/velox/functions/tests/FunctionRegistryTest.cpp index 0523af23a5d..e2a7869d928 100644 --- a/velox/functions/tests/FunctionRegistryTest.cpp +++ b/velox/functions/tests/FunctionRegistryTest.cpp @@ -24,6 +24,7 @@ #include "velox/functions/FunctionRegistry.h" #include "velox/functions/Macros.h" #include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/functions/prestosql/types/IPPrefixType.h" @@ -357,6 +358,26 @@ TEST_F(FunctionRegistryTest, isDeterministic) { ASSERT_FALSE(isDeterministic("not_found_function").has_value()); } +TEST_F(FunctionRegistryTest, companionFunction) { + functions::prestosql::registerAllScalarFunctions(); + aggregate::prestosql::registerAllAggregateFunctions(); + const auto functions = {"array_frequency", "bitwise_left_shift", "ceil"}; + // Aggregate companion functions with suffix '_extract' are registered as + // vector functions. + const auto companionFunctions = { + "array_agg_extract", "arbitrary_extract", "bitwise_and_agg_extract"}; + + for (const auto& function : functions) { + ASSERT_FALSE(exec::simpleFunctions() + .getFunctionSignaturesAndMetadata(function) + .front() + .first.companionFunction); + } + for (const auto& function : companionFunctions) { + ASSERT_TRUE(exec::getVectorFunctionMetadata(function)->companionFunction); + } +} + template struct TestFunction { VELOX_DEFINE_FUNCTION_TYPES(T);