Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import com.facebook.presto.operator.aggregation.ReduceAggregationFunction;
import com.facebook.presto.operator.aggregation.SumDataSizeForStats;
import com.facebook.presto.operator.aggregation.VarianceAggregation;
import com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent;
import com.facebook.presto.operator.aggregation.arrayagg.ArrayAggregationFunction;
import com.facebook.presto.operator.aggregation.arrayagg.SetAggregationFunction;
import com.facebook.presto.operator.aggregation.arrayagg.SetUnionFunction;
Expand Down Expand Up @@ -353,7 +354,6 @@
import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG;
import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT;
import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT_AND_COMPRESSION;
import static com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent.APPROXIMATE_MOST_FREQUENT;
import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMaxByAggregationFunction.ALTERNATIVE_MAX_BY;
import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMinByAggregationFunction.ALTERNATIVE_MIN_BY;
import static com.facebook.presto.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY;
Expand Down Expand Up @@ -937,7 +937,7 @@ private List<? extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC
.functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION)
.functions(MAP_TRANSFORM_KEY_FUNCTION, MAP_TRANSFORM_VALUE_FUNCTION)
.function(TRY_CAST)
.function(APPROXIMATE_MOST_FREQUENT)
.aggregate(ApproximateMostFrequent.class)
.function(K_DISTINCT)
.aggregate(MergeSetDigestAggregation.class)
.aggregate(BuildSetDigestAggregation.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,25 @@
*/
package com.facebook.presto.operator.aggregation.approxmostfrequent;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.NamedTypeSignature;
import com.facebook.presto.common.type.RowFieldName;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AccumulatorCompiler;
import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation;
import com.facebook.presto.operator.aggregation.NullablePosition;
import com.facebook.presto.operator.aggregation.approxmostfrequent.stream.StreamSummary;
import com.facebook.presto.spi.function.aggregation.Accumulator;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.google.common.collect.ImmutableList;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.BlockIndex;
import com.facebook.presto.spi.function.BlockPosition;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

import static com.facebook.presto.common.type.StandardTypes.BIGINT;
import static com.facebook.presto.common.type.StandardTypes.MAP;
import static com.facebook.presto.common.type.StandardTypes.ROW;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.function.Signature.comparableTypeParameter;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.util.Failures.checkCondition;
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Math.toIntExact;

/**
Expand All @@ -66,110 +46,36 @@
* by Ahmed Metwally, Divyakant Agrawal, and Amr El Abbadi
* </p>
*/
@AggregationFunction(value = "approx_most_frequent", isCalledOnNullInput = true)
@Description("Computes the top frequent elements approximately")
public final class ApproximateMostFrequent
extends SqlAggregationFunction
{
public static final ApproximateMostFrequent APPROXIMATE_MOST_FREQUENT = new ApproximateMostFrequent();
public static final String NAME = "approx_most_frequent";
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "output", ApproximateMostFrequentState.class, BlockBuilder.class);
private static final MethodHandle INPUT_FUNCTION = methodHandle(ApproximateMostFrequent.class, "input", Type.class, ApproximateMostFrequentState.class, long.class, Block.class, int.class, long.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(ApproximateMostFrequent.class, "combine", ApproximateMostFrequentState.class, ApproximateMostFrequentState.class);
private static final String MAX_BUCKETS = "max_buckets";
private static final String CAPACITY = "capacity";
private static final String KEYS = "keys";
private static final String VALUES = "values";

protected ApproximateMostFrequent()
{
super(NAME,
ImmutableList.of(comparableTypeParameter("K")),
ImmutableList.of(),
parseTypeSignature("map(K,bigint)"),
ImmutableList.of(parseTypeSignature(BIGINT), parseTypeSignature("K"), parseTypeSignature(BIGINT)));
}

@Override
public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
{
Type keyType = boundVariables.getTypeVariable("K");
checkArgument(keyType.isComparable(), "keyType must be comparable");
Type serializedType = functionAndTypeManager.getParameterizedType(ROW, ImmutableList.of(
buildTypeSignatureParameter(MAX_BUCKETS, BigintType.BIGINT),
buildTypeSignatureParameter(CAPACITY, BigintType.BIGINT),
buildTypeSignatureParameter(KEYS, new ArrayType(keyType)),
buildTypeSignatureParameter(VALUES, new ArrayType(BigintType.BIGINT))));
Type outputType = functionAndTypeManager.getParameterizedType(MAP, ImmutableList.of(
TypeSignatureParameter.of(keyType.getTypeSignature()),
TypeSignatureParameter.of(BigintType.BIGINT.getTypeSignature())));

DynamicClassLoader classLoader = new DynamicClassLoader(ApproximateMostFrequent.class.getClassLoader());
List<Type> inputTypes = ImmutableList.of(keyType);
ApproximateMostFrequentStateSerializer stateSerializer = new ApproximateMostFrequentStateSerializer(keyType, serializedType);
MethodHandle inputFunction = INPUT_FUNCTION.bindTo(keyType);

AggregationMetadata metadata = new AggregationMetadata(
generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
createInputParameterMetadata(keyType),
inputFunction,
COMBINE_FUNCTION,
OUTPUT_FUNCTION,
ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(
ApproximateMostFrequentState.class,
stateSerializer,
new ApproximateMostFrequentStateFactory())),
outputType);

Class<? extends Accumulator> accumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
Accumulator.class,
metadata,
classLoader);
Class<? extends GroupedAccumulator> groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
GroupedAccumulator.class,
metadata,
classLoader);
return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(serializedType), outputType,
true, false, metadata, accumulatorClass, groupedAccumulatorClass);
}

private TypeSignatureParameter buildTypeSignatureParameter(String fieldName, Type type)
{
return TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName(fieldName, false)), type.getTypeSignature()));
}

private static List<AggregationMetadata.ParameterMetadata> createInputParameterMetadata(Type keyType)
{
return ImmutableList.of(new AggregationMetadata.ParameterMetadata(STATE),
new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, BigintType.BIGINT),
new AggregationMetadata.ParameterMetadata(BLOCK_INPUT_CHANNEL, keyType),
new AggregationMetadata.ParameterMetadata(BLOCK_INDEX),
new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, BigintType.BIGINT));
}

@Override
public String getDescription()
{
return "Computes the top frequent elements approximately";
}
private ApproximateMostFrequent()
{}

public static void input(Type type,
ApproximateMostFrequentState state,
long buckets,
Block valueBlock,
int valueIndex, long capacity)
@InputFunction
@TypeParameter("T")
public static void input(
@TypeParameter("T") Type type,
@AggregationState ApproximateMostFrequentState state,
@SqlType(BIGINT) long buckets,
@BlockPosition @SqlType("T") @NullablePosition Block valueBlock,
@BlockIndex int valueIndex,
@SqlType(BIGINT) long capacity)
{
StreamSummary streamSummary = state.getStateSummary();
if (streamSummary == null) {
checkCondition(buckets > 1, INVALID_FUNCTION_ARGUMENT, "approx_most_frequent bucket count must be greater than one, input bucket count: %s", buckets);
streamSummary = new StreamSummary(
type,
toIntExact(buckets),
toIntExact(capacity));
streamSummary = new StreamSummary(type, toIntExact(buckets), toIntExact(capacity));
state.setStateSummary(streamSummary);
}
streamSummary.add(valueBlock, valueIndex, 1L);
}

public static void combine(ApproximateMostFrequentState state, ApproximateMostFrequentState otherState)
@CombineFunction
public static void combine(
@AggregationState ApproximateMostFrequentState state,
@AggregationState ApproximateMostFrequentState otherState)
{
StreamSummary streamSummary = state.getStateSummary();
if (streamSummary == null) {
Expand All @@ -180,7 +86,8 @@ public static void combine(ApproximateMostFrequentState state, ApproximateMostFr
}
}

public static void output(ApproximateMostFrequentState state, BlockBuilder out)
@OutputFunction("map(T,bigint)")
public static void output(@AggregationState ApproximateMostFrequentState state, BlockBuilder out)
{
if (state.getStateSummary() == null) {
out.appendNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,32 @@

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.approxmostfrequent.stream.StreamSummary;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.TypeParameter;
import com.google.common.collect.ImmutableList;

public class ApproximateMostFrequentStateSerializer
implements AccumulatorStateSerializer<ApproximateMostFrequentState>
{
private final Type type;
private final Type serializedType;
private static final String MAX_BUCKETS = "max_buckets";
private static final String CAPACITY = "capacity";
private static final String KEYS = "keys";
private static final String VALUES = "values";

public ApproximateMostFrequentStateSerializer(Type type, Type serializedType)
public ApproximateMostFrequentStateSerializer(@TypeParameter("T") Type type)
{
this.type = type;
this.serializedType = serializedType;
this.serializedType = RowType.from(ImmutableList.of(RowType.field(MAX_BUCKETS, BigintType.BIGINT),
RowType.field(CAPACITY, BigintType.BIGINT),
RowType.field(KEYS, new ArrayType(type)),
RowType.field(VALUES, new ArrayType(BigintType.BIGINT))));
}

@Override
Expand Down