// Patch summary and review notes. Everything after this header is the original diff text, unchanged.
//
// Purpose: re-implements the approx_most_frequent aggregation as an explicit SqlAggregationFunction
// instead of an annotation-driven aggregation class.
//
// Files touched:
//   1. BuiltInTypeAndFunctionNamespaceManager
//      - registration changes from .aggregate(ApproximateMostFrequent.class) to
//        .function(APPROXIMATE_MOST_FREQUENT); the class import is replaced by a static import of the
//        singleton constant.
//   2. ApproximateMostFrequent
//      - drops @AggregationFunction/@Description/@InputFunction/@CombineFunction/@OutputFunction and now
//        extends SqlAggregationFunction with the signature (bigint, K, bigint) -> map(K,bigint),
//        where K is declared through comparableTypeParameter("K").
//      - specialize() reads the bound type for K, builds the intermediate ROW type with the fields
//        max_buckets, capacity, keys and values, builds the MAP output type, binds the key type into the
//        input method handle, and generates the Accumulator and GroupedAccumulator classes through
//        AccumulatorCompiler.
//      - input(), combine() and output() keep their bodies; only the parameter annotations are removed
//        (input() additionally has its StreamSummary constructor call re-wrapped over several lines).
//      - getDescription() returns the text that used to live in @Description.
//   3. ApproximateMostFrequentStateSerializer
//      - the constructor now receives serializedType from the caller instead of building the RowType
//        itself; the field-name constants move to ApproximateMostFrequent.
//
// NOTE(review): the removed annotations declared isCalledOnNullInput = true and marked the value block
// @NullablePosition, while the new createInputParameterMetadata() describes that argument as
// BLOCK_INPUT_CHANNEL. Confirm that rows with a null value are still handled the way they were before.
// NOTE(review): inputTypes in specialize() contains only keyType, although the declared signature has
// three arguments (bigint, K, bigint). That list feeds generateAggregationName(...) and the
// BuiltInAggregationFunctionImplementation constructor - confirm a single-element list is intended.
// NOTE(review): the type variable was previously declared with @TypeParameter("T") and is now
// comparableTypeParameter("K"); specialize() also asserts keyType.isComparable(). Confirm that
// restricting the function to comparable types is intended.
// NOTE(review): APPROXIMATE_MOST_FREQUENT is instantiated before the other static fields of the class are
// initialised. The constructor shown here only uses NAME (a constant string), so this looks harmless,
// but keep it in mind if the constructor ever starts reading the MethodHandle fields.
// NOTE(review): generic type arguments seem to be missing in this copy of the patch (for example
// "List inputTypes", "Class accumulatorClass", "private List getBuildInFunctions") and the line
// structure of the diff has been collapsed - presumably an extraction artifact; verify against the
// original patch before applying.
diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 76cd8feac8194..11fb91965fec9 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -88,7 +88,6 @@ import com.facebook.presto.operator.aggregation.ReduceAggregationFunction; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; import com.facebook.presto.operator.aggregation.VarianceAggregation; -import com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent; import com.facebook.presto.operator.aggregation.arrayagg.ArrayAggregationFunction; import com.facebook.presto.operator.aggregation.arrayagg.SetAggregationFunction; import com.facebook.presto.operator.aggregation.arrayagg.SetUnionFunction; @@ -354,6 +353,7 @@ import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT_AND_COMPRESSION; +import static com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent.APPROXIMATE_MOST_FREQUENT; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMaxByAggregationFunction.ALTERNATIVE_MAX_BY; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMinByAggregationFunction.ALTERNATIVE_MIN_BY; import static com.facebook.presto.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY; @@ -937,7 +937,7 @@ private List getBuildInFunctions(FeaturesConfig featuresC .functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION) 
.functions(MAP_TRANSFORM_KEY_FUNCTION, MAP_TRANSFORM_VALUE_FUNCTION) .function(TRY_CAST) - .aggregate(ApproximateMostFrequent.class) + .function(APPROXIMATE_MOST_FREQUENT) .function(K_DISTINCT) .aggregate(MergeSetDigestAggregation.class) .aggregate(BuildSetDigestAggregation.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java index f762028e48ee5..759af63a810bd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java @@ -13,25 +13,45 @@ */ package com.facebook.presto.operator.aggregation.approxmostfrequent; +import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.NamedTypeSignature; +import com.facebook.presto.common.type.RowFieldName; import com.facebook.presto.common.type.Type; -import com.facebook.presto.operator.aggregation.NullablePosition; +import com.facebook.presto.common.type.TypeSignatureParameter; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.SqlAggregationFunction; +import com.facebook.presto.operator.aggregation.AccumulatorCompiler; +import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; import com.facebook.presto.operator.aggregation.approxmostfrequent.stream.StreamSummary; -import com.facebook.presto.spi.function.AggregationFunction; -import com.facebook.presto.spi.function.AggregationState; 
-import com.facebook.presto.spi.function.BlockIndex; -import com.facebook.presto.spi.function.BlockPosition; -import com.facebook.presto.spi.function.CombineFunction; -import com.facebook.presto.spi.function.Description; -import com.facebook.presto.spi.function.InputFunction; -import com.facebook.presto.spi.function.OutputFunction; -import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.function.aggregation.Accumulator; +import com.facebook.presto.spi.function.aggregation.AggregationMetadata; +import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; +import com.google.common.collect.ImmutableList; + +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.Optional; import static com.facebook.presto.common.type.StandardTypes.BIGINT; +import static com.facebook.presto.common.type.StandardTypes.MAP; +import static com.facebook.presto.common.type.StandardTypes.ROW; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.Signature.comparableTypeParameter; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.util.Failures.checkCondition; +import static com.facebook.presto.util.Reflection.methodHandle; +import static 
com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Math.toIntExact; /** @@ -46,36 +66,110 @@ * by Ahmed Metwally, Divyakant Agrawal, and Amr El Abbadi *

*/ -@AggregationFunction(value = "approx_most_frequent", isCalledOnNullInput = true) -@Description("Computes the top frequent elements approximately") public final class ApproximateMostFrequent + extends SqlAggregationFunction { - private ApproximateMostFrequent() - {} + public static final ApproximateMostFrequent APPROXIMATE_MOST_FREQUENT = new ApproximateMostFrequent(); + public static final String NAME = "approx_most_frequent"; + private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "output", ApproximateMostFrequentState.class, BlockBuilder.class); + private static final MethodHandle INPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "input", Type.class, ApproximateMostFrequentState.class, long.class, Block.class, int.class, long.class); + private static final MethodHandle COMBINE_FUNCTION = methodHandle(ApproximateMostFrequent.class, "combine", ApproximateMostFrequentState.class, ApproximateMostFrequentState.class); + private static final String MAX_BUCKETS = "max_buckets"; + private static final String CAPACITY = "capacity"; + private static final String KEYS = "keys"; + private static final String VALUES = "values"; + + protected ApproximateMostFrequent() + { + super(NAME, + ImmutableList.of(comparableTypeParameter("K")), + ImmutableList.of(), + parseTypeSignature("map(K,bigint)"), + ImmutableList.of(parseTypeSignature(BIGINT), parseTypeSignature("K"), parseTypeSignature(BIGINT))); + } + + @Override + public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) + { + Type keyType = boundVariables.getTypeVariable("K"); + checkArgument(keyType.isComparable(), "keyType must be comparable"); + Type serializedType = functionAndTypeManager.getParameterizedType(ROW, ImmutableList.of( + buildTypeSignatureParameter(MAX_BUCKETS, BigintType.BIGINT), + buildTypeSignatureParameter(CAPACITY, BigintType.BIGINT), 
+ buildTypeSignatureParameter(KEYS, new ArrayType(keyType)), + buildTypeSignatureParameter(VALUES, new ArrayType(BigintType.BIGINT)))); + Type outputType = functionAndTypeManager.getParameterizedType(MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(BigintType.BIGINT.getTypeSignature()))); + + DynamicClassLoader classLoader = new DynamicClassLoader(ApproximateMostFrequent.class.getClassLoader()); + List inputTypes = ImmutableList.of(keyType); + ApproximateMostFrequentStateSerializer stateSerializer = new ApproximateMostFrequentStateSerializer(keyType, serializedType); + MethodHandle inputFunction = INPUT_FUNCTION.bindTo(keyType); + + AggregationMetadata metadata = new AggregationMetadata( + generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), + createInputParameterMetadata(keyType), + inputFunction, + COMBINE_FUNCTION, + OUTPUT_FUNCTION, + ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor( + ApproximateMostFrequentState.class, + stateSerializer, + new ApproximateMostFrequentStateFactory())), + outputType); + + Class accumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + Accumulator.class, + metadata, + classLoader); + Class groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + GroupedAccumulator.class, + metadata, + classLoader); + return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(serializedType), outputType, + true, false, metadata, accumulatorClass, groupedAccumulatorClass); + } + + private TypeSignatureParameter buildTypeSignatureParameter(String fieldName, Type type) + { + return TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName(fieldName, false)), type.getTypeSignature())); + } + + private static List createInputParameterMetadata(Type keyType) + { + return ImmutableList.of(new 
AggregationMetadata.ParameterMetadata(STATE), + new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, BigintType.BIGINT), + new AggregationMetadata.ParameterMetadata(BLOCK_INPUT_CHANNEL, keyType), + new AggregationMetadata.ParameterMetadata(BLOCK_INDEX), + new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, BigintType.BIGINT)); + } + + @Override + public String getDescription() + { + return "Computes the top frequent elements approximately"; + } - @InputFunction - @TypeParameter("T") - public static void input( - @TypeParameter("T") Type type, - @AggregationState ApproximateMostFrequentState state, - @SqlType(BIGINT) long buckets, - @BlockPosition @SqlType("T") @NullablePosition Block valueBlock, - @BlockIndex int valueIndex, - @SqlType(BIGINT) long capacity) + public static void input(Type type, + ApproximateMostFrequentState state, + long buckets, + Block valueBlock, + int valueIndex, long capacity) { StreamSummary streamSummary = state.getStateSummary(); if (streamSummary == null) { checkCondition(buckets > 1, INVALID_FUNCTION_ARGUMENT, "approx_most_frequent bucket count must be greater than one, input bucket count: %s", buckets); - streamSummary = new StreamSummary(type, toIntExact(buckets), toIntExact(capacity)); + streamSummary = new StreamSummary( + type, + toIntExact(buckets), + toIntExact(capacity)); state.setStateSummary(streamSummary); } streamSummary.add(valueBlock, valueIndex, 1L); } - @CombineFunction - public static void combine( - @AggregationState ApproximateMostFrequentState state, - @AggregationState ApproximateMostFrequentState otherState) + public static void combine(ApproximateMostFrequentState state, ApproximateMostFrequentState otherState) { StreamSummary streamSummary = state.getStateSummary(); if (streamSummary == null) { @@ -86,8 +180,7 @@ public static void combine( } } - @OutputFunction("map(T,bigint)") - public static void output(@AggregationState ApproximateMostFrequentState state, BlockBuilder out) + public static void 
output(ApproximateMostFrequentState state, BlockBuilder out) { if (state.getStateSummary() == null) { out.appendNull(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java index 0ec5200d440cf..e1cca1926bde1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequentStateSerializer.java @@ -15,32 +15,20 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; -import com.facebook.presto.common.type.ArrayType; -import com.facebook.presto.common.type.BigintType; -import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.aggregation.approxmostfrequent.stream.StreamSummary; import com.facebook.presto.spi.function.AccumulatorStateSerializer; -import com.facebook.presto.spi.function.TypeParameter; -import com.google.common.collect.ImmutableList; public class ApproximateMostFrequentStateSerializer implements AccumulatorStateSerializer { private final Type type; private final Type serializedType; - private static final String MAX_BUCKETS = "max_buckets"; - private static final String CAPACITY = "capacity"; - private static final String KEYS = "keys"; - private static final String VALUES = "values"; - public ApproximateMostFrequentStateSerializer(@TypeParameter("T") Type type) + public ApproximateMostFrequentStateSerializer(Type type, Type serializedType) { this.type = type; - this.serializedType = RowType.from(ImmutableList.of(RowType.field(MAX_BUCKETS, BigintType.BIGINT), - RowType.field(CAPACITY, BigintType.BIGINT), - RowType.field(KEYS, new 
ArrayType(type)), - RowType.field(VALUES, new ArrayType(BigintType.BIGINT)))); + this.serializedType = serializedType; } @Override