From fd023c52cc0b58524bcd76513e68e6d138633eca Mon Sep 17 00:00:00 2001 From: feilong-liu <51964150+feilong-liu@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:16:41 -0800 Subject: [PATCH] Revert "Refactor `ApproximateMostFrequent` to use type annotations" --- ...uiltInTypeAndFunctionNamespaceManager.java | 4 +- .../ApproximateMostFrequent.java | 155 ++++++++++++++---- ...pproximateMostFrequentStateSerializer.java | 16 +- 3 files changed, 128 insertions(+), 47 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 76cd8feac8194..11fb91965fec9 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -88,7 +88,6 @@ import com.facebook.presto.operator.aggregation.ReduceAggregationFunction; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; import com.facebook.presto.operator.aggregation.VarianceAggregation; -import com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent; import com.facebook.presto.operator.aggregation.arrayagg.ArrayAggregationFunction; import com.facebook.presto.operator.aggregation.arrayagg.SetAggregationFunction; import com.facebook.presto.operator.aggregation.arrayagg.SetUnionFunction; @@ -354,6 +353,7 @@ import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT_AND_COMPRESSION; +import static com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent.APPROXIMATE_MOST_FREQUENT; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMaxByAggregationFunction.ALTERNATIVE_MAX_BY; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMinByAggregationFunction.ALTERNATIVE_MIN_BY; import static com.facebook.presto.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY; @@ -937,7 +937,7 @@ private List extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC .functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION) .functions(MAP_TRANSFORM_KEY_FUNCTION, MAP_TRANSFORM_VALUE_FUNCTION) .function(TRY_CAST) - .aggregate(ApproximateMostFrequent.class) + .function(APPROXIMATE_MOST_FREQUENT) .function(K_DISTINCT) .aggregate(MergeSetDigestAggregation.class) .aggregate(BuildSetDigestAggregation.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java index f762028e48ee5..759af63a810bd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/approxmostfrequent/ApproximateMostFrequent.java @@ -13,25 +13,45 @@ */ package com.facebook.presto.operator.aggregation.approxmostfrequent; +import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.NamedTypeSignature; +import com.facebook.presto.common.type.RowFieldName; import com.facebook.presto.common.type.Type; -import com.facebook.presto.operator.aggregation.NullablePosition; +import com.facebook.presto.common.type.TypeSignatureParameter; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.SqlAggregationFunction; +import com.facebook.presto.operator.aggregation.AccumulatorCompiler; +import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; import com.facebook.presto.operator.aggregation.approxmostfrequent.stream.StreamSummary; -import com.facebook.presto.spi.function.AggregationFunction; -import com.facebook.presto.spi.function.AggregationState; -import com.facebook.presto.spi.function.BlockIndex; -import com.facebook.presto.spi.function.BlockPosition; -import com.facebook.presto.spi.function.CombineFunction; -import com.facebook.presto.spi.function.Description; -import com.facebook.presto.spi.function.InputFunction; -import com.facebook.presto.spi.function.OutputFunction; -import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.function.aggregation.Accumulator; +import com.facebook.presto.spi.function.aggregation.AggregationMetadata; +import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; +import com.google.common.collect.ImmutableList; + +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.Optional; import static com.facebook.presto.common.type.StandardTypes.BIGINT; +import static com.facebook.presto.common.type.StandardTypes.MAP; +import static com.facebook.presto.common.type.StandardTypes.ROW; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.Signature.comparableTypeParameter; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.util.Failures.checkCondition; +import static com.facebook.presto.util.Reflection.methodHandle; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Math.toIntExact; /** @@ -46,36 +66,110 @@ * by Ahmed Metwally, Divyakant Agrawal, and Amr El Abbadi *
*/ -@AggregationFunction(value = "approx_most_frequent", isCalledOnNullInput = true) -@Description("Computes the top frequent elements approximately") public final class ApproximateMostFrequent + extends SqlAggregationFunction { - private ApproximateMostFrequent() - {} + public static final ApproximateMostFrequent APPROXIMATE_MOST_FREQUENT = new ApproximateMostFrequent(); + public static final String NAME = "approx_most_frequent"; + private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "output", ApproximateMostFrequentState.class, BlockBuilder.class); + private static final MethodHandle INPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "input", Type.class, ApproximateMostFrequentState.class, long.class, Block.class, int.class, long.class); + private static final MethodHandle COMBINE_FUNCTION = methodHandle(ApproximateMostFrequent.class, "combine", ApproximateMostFrequentState.class, ApproximateMostFrequentState.class); + private static final String MAX_BUCKETS = "max_buckets"; + private static final String CAPACITY = "capacity"; + private static final String KEYS = "keys"; + private static final String VALUES = "values"; + + protected ApproximateMostFrequent() + { + super(NAME, + ImmutableList.of(comparableTypeParameter("K")), + ImmutableList.of(), + parseTypeSignature("map(K,bigint)"), + ImmutableList.of(parseTypeSignature(BIGINT), parseTypeSignature("K"), parseTypeSignature(BIGINT))); + } + + @Override + public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) + { + Type keyType = boundVariables.getTypeVariable("K"); + checkArgument(keyType.isComparable(), "keyType must be comparable"); + Type serializedType = functionAndTypeManager.getParameterizedType(ROW, ImmutableList.of( + buildTypeSignatureParameter(MAX_BUCKETS, BigintType.BIGINT), + buildTypeSignatureParameter(CAPACITY, BigintType.BIGINT), + buildTypeSignatureParameter(KEYS, new ArrayType(keyType)), + buildTypeSignatureParameter(VALUES, new ArrayType(BigintType.BIGINT)))); + Type outputType = functionAndTypeManager.getParameterizedType(MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(BigintType.BIGINT.getTypeSignature()))); + + DynamicClassLoader classLoader = new DynamicClassLoader(ApproximateMostFrequent.class.getClassLoader()); + List