# Patch summary: converts the approx_most_frequent aggregation from a hand-written
# SqlAggregationFunction subclass into an annotation-driven aggregation class.
#
# Four `diff --git` sections follow:
#   1. BuiltInTypeAndFunctionNamespaceManager.java
#      Swaps the static import of APPROXIMATE_MOST_FREQUENT for a regular import of
#      ApproximateMostFrequent, and registers the function with
#      .aggregate(ApproximateMostFrequent.class) instead of
#      .function(APPROXIMATE_MOST_FREQUENT).
#   2. ApproximateMostFrequent.java
#      Removes the singleton constant, NAME, the three MethodHandle fields, the
#      protected constructor that declared the signature, specialize(),
#      buildTypeSignatureParameter(), createInputParameterMetadata() and
#      getDescription(). Adds @AggregationFunction("approx_most_frequent") and
#      @Description on the class, a private no-arg constructor, and
#      @InputFunction / @CombineFunction / @OutputFunction("map(T,bigint)") on the
#      static input/combine/output methods. The type variable is renamed from K to T.
#      NOTE(review): the removed code declared K with comparableTypeParameter("K") and
#      checked keyType.isComparable() in specialize(); the added code shows only a
#      plain @TypeParameter("T"). Confirm comparability is still enforced somewhere.
#      NOTE(review): input() validates only `buckets > 1`; both buckets and capacity
#      are then narrowed with toIntExact(). Confirm out-of-range values surface as the
#      intended user-facing error.
#   3. ApproximateMostFrequentStateSerializer.java
#      The constructor now takes only the key type (annotated @TypeParameter("T")) and
#      builds the serialized row type itself with RowType.from(...), using the fields
#      max_buckets, capacity, keys and values, instead of receiving the serialized type
#      as a second argument. The four field-name constants move here from
#      ApproximateMostFrequent.
#   4. AbstractTestQueries.java
#      Adds testApproxMostFrequentWithNullGroupBy: for input rows (1, null) and (2, 3)
#      grouped by k, it expects a null result for k = 1 and the map {3 -> 1} for k = 2.
#
# NOTE(review): line breaks inside the hunks below appear to have been collapsed into
# spaces, so the hunk line counts cannot be checked against this text as-is.
diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 11fb91965fec9..76cd8feac8194 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -88,6 +88,7 @@ import com.facebook.presto.operator.aggregation.ReduceAggregationFunction; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; import com.facebook.presto.operator.aggregation.VarianceAggregation; +import com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent; import com.facebook.presto.operator.aggregation.arrayagg.ArrayAggregationFunction; import com.facebook.presto.operator.aggregation.arrayagg.SetAggregationFunction; import com.facebook.presto.operator.aggregation.arrayagg.SetUnionFunction; @@ -353,7 +354,6 @@ import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT_AND_COMPRESSION; -import static com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent.APPROXIMATE_MOST_FREQUENT; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMaxByAggregationFunction.ALTERNATIVE_MAX_BY; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMinByAggregationFunction.ALTERNATIVE_MIN_BY; import static com.facebook.presto.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY; @@ -937,7 +937,7 @@ private List getBuildInFunctions(FeaturesConfig featuresC .functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION) 
.functions(MAP_TRANSFORM_KEY_FUNCTION, MAP_TRANSFORM_VALUE_FUNCTION) .function(TRY_CAST) - .function(APPROXIMATE_MOST_FREQUENT) + .aggregate(ApproximateMostFrequent.class) .function(K_DISTINCT) .aggregate(MergeSetDigestAggregation.class) .aggregate(BuildSetDigestAggregation.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java index 759af63a810bd..bea895b312144 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java @@ -13,45 +13,24 @@ */ package com.facebook.presto.operator.aggregation.approxmostfrequent; -import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; -import com.facebook.presto.common.type.ArrayType; -import com.facebook.presto.common.type.BigintType; -import com.facebook.presto.common.type.NamedTypeSignature; -import com.facebook.presto.common.type.RowFieldName; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.TypeSignatureParameter; -import com.facebook.presto.metadata.BoundVariables; -import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.SqlAggregationFunction; -import com.facebook.presto.operator.aggregation.AccumulatorCompiler; -import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; import com.facebook.presto.operator.aggregation.approxmostfrequent.stream.StreamSummary; -import com.facebook.presto.spi.function.aggregation.Accumulator; -import com.facebook.presto.spi.function.aggregation.AggregationMetadata; -import 
com.facebook.presto.spi.function.aggregation.GroupedAccumulator; -import com.google.common.collect.ImmutableList; - -import java.lang.invoke.MethodHandle; -import java.util.List; -import java.util.Optional; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.BlockIndex; +import com.facebook.presto.spi.function.BlockPosition; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; import static com.facebook.presto.common.type.StandardTypes.BIGINT; -import static com.facebook.presto.common.type.StandardTypes.MAP; -import static com.facebook.presto.common.type.StandardTypes.ROW; -import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static com.facebook.presto.spi.function.Signature.comparableTypeParameter; -import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; -import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; -import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL; -import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.util.Failures.checkCondition; -import static com.facebook.presto.util.Reflection.methodHandle; -import static 
com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Math.toIntExact; /** @@ -66,110 +45,36 @@ * by Ahmed Metwally, Divyakant Agrawal, and Amr El Abbadi *

*/ +@AggregationFunction(value = "approx_most_frequent") +@Description("Computes the top frequent elements approximately") public final class ApproximateMostFrequent - extends SqlAggregationFunction { - public static final ApproximateMostFrequent APPROXIMATE_MOST_FREQUENT = new ApproximateMostFrequent(); - public static final String NAME = "approx_most_frequent"; - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "output", ApproximateMostFrequentState.class, BlockBuilder.class); - private static final MethodHandle INPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "input", Type.class, ApproximateMostFrequentState.class, long.class, Block.class, int.class, long.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(ApproximateMostFrequent.class, "combine", ApproximateMostFrequentState.class, ApproximateMostFrequentState.class); - private static final String MAX_BUCKETS = "max_buckets"; - private static final String CAPACITY = "capacity"; - private static final String KEYS = "keys"; - private static final String VALUES = "values"; - - protected ApproximateMostFrequent() - { - super(NAME, - ImmutableList.of(comparableTypeParameter("K")), - ImmutableList.of(), - parseTypeSignature("map(K,bigint)"), - ImmutableList.of(parseTypeSignature(BIGINT), parseTypeSignature("K"), parseTypeSignature(BIGINT))); - } - - @Override - public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) - { - Type keyType = boundVariables.getTypeVariable("K"); - checkArgument(keyType.isComparable(), "keyType must be comparable"); - Type serializedType = functionAndTypeManager.getParameterizedType(ROW, ImmutableList.of( - buildTypeSignatureParameter(MAX_BUCKETS, BigintType.BIGINT), - buildTypeSignatureParameter(CAPACITY, BigintType.BIGINT), - buildTypeSignatureParameter(KEYS, new ArrayType(keyType)), - 
buildTypeSignatureParameter(VALUES, new ArrayType(BigintType.BIGINT)))); - Type outputType = functionAndTypeManager.getParameterizedType(MAP, ImmutableList.of( - TypeSignatureParameter.of(keyType.getTypeSignature()), - TypeSignatureParameter.of(BigintType.BIGINT.getTypeSignature()))); - - DynamicClassLoader classLoader = new DynamicClassLoader(ApproximateMostFrequent.class.getClassLoader()); - List inputTypes = ImmutableList.of(keyType); - ApproximateMostFrequentStateSerializer stateSerializer = new ApproximateMostFrequentStateSerializer(keyType, serializedType); - MethodHandle inputFunction = INPUT_FUNCTION.bindTo(keyType); - - AggregationMetadata metadata = new AggregationMetadata( - generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), - createInputParameterMetadata(keyType), - inputFunction, - COMBINE_FUNCTION, - OUTPUT_FUNCTION, - ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor( - ApproximateMostFrequentState.class, - stateSerializer, - new ApproximateMostFrequentStateFactory())), - outputType); - - Class accumulatorClass = AccumulatorCompiler.generateAccumulatorClass( - Accumulator.class, - metadata, - classLoader); - Class groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass( - GroupedAccumulator.class, - metadata, - classLoader); - return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(serializedType), outputType, - true, false, metadata, accumulatorClass, groupedAccumulatorClass); - } - - private TypeSignatureParameter buildTypeSignatureParameter(String fieldName, Type type) - { - return TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName(fieldName, false)), type.getTypeSignature())); - } - - private static List createInputParameterMetadata(Type keyType) - { - return ImmutableList.of(new AggregationMetadata.ParameterMetadata(STATE), - new 
AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, BigintType.BIGINT), - new AggregationMetadata.ParameterMetadata(BLOCK_INPUT_CHANNEL, keyType), - new AggregationMetadata.ParameterMetadata(BLOCK_INDEX), - new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, BigintType.BIGINT)); - } - - @Override - public String getDescription() - { - return "Computes the top frequent elements approximately"; - } + private ApproximateMostFrequent() + {} - public static void input(Type type, - ApproximateMostFrequentState state, - long buckets, - Block valueBlock, - int valueIndex, long capacity) + @InputFunction + @TypeParameter("T") + public static void input( + @TypeParameter("T") Type type, + @AggregationState ApproximateMostFrequentState state, + @SqlType(BIGINT) long buckets, + @BlockPosition @SqlType("T") Block valueBlock, + @BlockIndex int valueIndex, + @SqlType(BIGINT) long capacity) { StreamSummary streamSummary = state.getStateSummary(); if (streamSummary == null) { checkCondition(buckets > 1, INVALID_FUNCTION_ARGUMENT, "approx_most_frequent bucket count must be greater than one, input bucket count: %s", buckets); - streamSummary = new StreamSummary( - type, - toIntExact(buckets), - toIntExact(capacity)); + streamSummary = new StreamSummary(type, toIntExact(buckets), toIntExact(capacity)); state.setStateSummary(streamSummary); } streamSummary.add(valueBlock, valueIndex, 1L); } - public static void combine(ApproximateMostFrequentState state, ApproximateMostFrequentState otherState) + @CombineFunction + public static void combine( + @AggregationState ApproximateMostFrequentState state, + @AggregationState ApproximateMostFrequentState otherState) { StreamSummary streamSummary = state.getStateSummary(); if (streamSummary == null) { @@ -180,7 +85,8 @@ public static void combine(ApproximateMostFrequentState state, ApproximateMostFr } } - public static void output(ApproximateMostFrequentState state, BlockBuilder out) + @OutputFunction("map(T,bigint)") + public static void 
output(@AggregationState ApproximateMostFrequentState state, BlockBuilder out) { if (state.getStateSummary() == null) { out.appendNull(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java index e1cca1926bde1..0ec5200d440cf 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java @@ -15,20 +15,32 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.aggregation.approxmostfrequent.stream.StreamSummary; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.function.TypeParameter; +import com.google.common.collect.ImmutableList; public class ApproximateMostFrequentStateSerializer implements AccumulatorStateSerializer { private final Type type; private final Type serializedType; + private static final String MAX_BUCKETS = "max_buckets"; + private static final String CAPACITY = "capacity"; + private static final String KEYS = "keys"; + private static final String VALUES = "values"; - public ApproximateMostFrequentStateSerializer(Type type, Type serializedType) + public ApproximateMostFrequentStateSerializer(@TypeParameter("T") Type type) { this.type = type; - this.serializedType = serializedType; + this.serializedType = RowType.from(ImmutableList.of(RowType.field(MAX_BUCKETS, BigintType.BIGINT), + 
RowType.field(CAPACITY, BigintType.BIGINT), + RowType.field(KEYS, new ArrayType(type)), + RowType.field(VALUES, new ArrayType(BigintType.BIGINT)))); } @Override diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index eee827b375c8c..25d20db2d3f93 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -5979,6 +5979,18 @@ public void testApproxMostFrequentWithStringGroupBy() assertEquals(actual1.getMaterializedRows().get(2).getFields().get(1), ImmutableMap.of("C", 2L)); } + @Test + public void testApproxMostFrequentWithNullGroupBy() + { + MaterializedResult actual1 = computeActual("select k, approx_most_frequent(2, v, 10) from (values (1, null), (2, 3)) t(k, v) group by k order by k"); + + assertEquals(actual1.getRowCount(), 2); + assertEquals(actual1.getMaterializedRows().get(0).getFields().get(0), 1); + assertEquals(actual1.getMaterializedRows().get(0).getFields().get(1), null); + assertEquals(actual1.getMaterializedRows().get(1).getFields().get(0), 2); + assertEquals(actual1.getMaterializedRows().get(1).getFields().get(1), ImmutableMap.of(3, 1L)); + } + @Test public void testUnknownMaxBy() {