diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 8dfa8bc6f9b99..11fb91965fec9 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -944,7 +944,7 @@ private List getBuildInFunctions(FeaturesConfig featuresC .scalars(SetDigestFunctions.class) .scalars(SetDigestOperators.class) .aggregates(MergeKHyperLogLogAggregationFunction.class) - .aggregates(KHyperLogLogAggregationFunction.class) + .function(new KHyperLogLogAggregationFunction(featuresConfig.getKHyperLogLogAggregationGroupNumberLimit())) .scalars(KHyperLogLogFunctions.class) .scalars(KHyperLogLogOperators.class) .scalars(WilsonInterval.class) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index c1f0675c7dbbe..54072b1e3832e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -295,6 +295,7 @@ public class FeaturesConfig private boolean removeRedundantCastToVarcharInJoin = true; private boolean skipHashGenerationForJoinWithTableScanInput; + private long kHyperLogLogAggregationGroupNumberLimit; public enum PartitioningPrecisionStrategy { @@ -2957,4 +2958,17 @@ public FeaturesConfig setSkipHashGenerationForJoinWithTableScanInput(boolean ski this.skipHashGenerationForJoinWithTableScanInput = skipHashGenerationForJoinWithTableScanInput; return this; } + + public long getKHyperLogLogAggregationGroupNumberLimit() + { + return kHyperLogLogAggregationGroupNumberLimit; + } + + @Config("khyperloglog-agg-group-limit") + @ConfigDescription("Maximum number of groups for khyperloglog_agg per task") + public FeaturesConfig setKHyperLogLogAggregationGroupNumberLimit(long kHyperLogLogAggregationGroupNumberLimit) + { + this.kHyperLogLogAggregationGroupNumberLimit = kHyperLogLogAggregationGroupNumberLimit; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogAggregationFunction.java index 666e1d738d7be..b48a18865d884 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogAggregationFunction.java @@ -11,30 +11,141 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.facebook.presto.type.khyperloglog; +import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.common.block.BlockBuilder; -import com.facebook.presto.common.type.StandardTypes; -import com.facebook.presto.spi.function.AggregationFunction; -import com.facebook.presto.spi.function.AggregationState; -import com.facebook.presto.spi.function.CombineFunction; -import com.facebook.presto.spi.function.InputFunction; -import com.facebook.presto.spi.function.LiteralParameters; -import com.facebook.presto.spi.function.OutputFunction; -import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.SqlAggregationFunction; +import com.facebook.presto.operator.aggregation.AccumulatorCompiler; +import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AccumulatorState; +import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.function.aggregation.Accumulator; +import com.facebook.presto.spi.function.aggregation.AggregationMetadata; +import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; +import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; -@AggregationFunction("khyperloglog_agg") +import java.lang.invoke.MethodHandle; +import java.util.List; + +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.Signature.typeVariable; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; +import static com.facebook.presto.type.khyperloglog.KHyperLogLogType.K_HYPER_LOG_LOG; +import static com.facebook.presto.util.Reflection.methodHandle; +import static com.google.common.collect.ImmutableList.toImmutableList; + public final class KHyperLogLogAggregationFunction + extends SqlAggregationFunction { + private static final String NAME = "khyperloglog_agg"; private static final KHyperLogLogStateSerializer SERIALIZER = new KHyperLogLogStateSerializer(); + private static final MethodHandle LONG_LONG_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, long.class, long.class); + private static final MethodHandle SLICE_LONG_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, Slice.class, long.class); + private static final MethodHandle DOUBLE_LONG_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, double.class, long.class); + private static final MethodHandle LONG_SLICE_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, long.class, Slice.class); + private static final MethodHandle SLICE_SLICE_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, Slice.class, Slice.class); + private static final MethodHandle DOUBLE_SLICE_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, double.class, Slice.class); + private static final MethodHandle OUTPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "output", KHyperLogLogState.class, BlockBuilder.class); + private static final MethodHandle COMBINE_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "combine", KHyperLogLogState.class, KHyperLogLogState.class); + private final long groupLimit; + + public KHyperLogLogAggregationFunction(long groupLimit) + { + super(NAME, ImmutableList.of(typeVariable("E"), typeVariable("T")), ImmutableList.of(), K_HYPER_LOG_LOG.getTypeSignature(), ImmutableList.of(parseTypeSignature("E"), parseTypeSignature("T"))); + this.groupLimit = groupLimit; + } + + public static String getFunctionName() + { + return NAME; + } + + @Override + public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) + { + Type firstInputType = boundVariables.getTypeVariable("E"); + Type secondInputType = boundVariables.getTypeVariable("T"); + return generateAggregation(firstInputType, secondInputType); + } + + private BuiltInAggregationFunctionImplementation generateAggregation(Type firstInputType, Type secondInputType) + { + DynamicClassLoader classLoader = new DynamicClassLoader(KHyperLogLogAggregationFunction.class.getClassLoader()); + List inputTypes = ImmutableList.of(firstInputType, secondInputType); + Class stateInterface = KHyperLogLogState.class; + AccumulatorStateSerializer stateSerializer = new KHyperLogLogStateSerializer(); + MethodHandle inputFunction = getMethodHandle(firstInputType, secondInputType); + + AggregationMetadata metadata = new AggregationMetadata( + generateAggregationName(NAME, K_HYPER_LOG_LOG.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), + ImmutableList.of(new AggregationMetadata.ParameterMetadata(STATE), new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, firstInputType), new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, secondInputType)), + inputFunction, + COMBINE_FUNCTION, + OUTPUT_FUNCTION, + ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor( + stateInterface, + stateSerializer, + new KHyperLogLogStateFactory(groupLimit))), + K_HYPER_LOG_LOG); + + Type intermediateType = stateSerializer.getSerializedType(); - private KHyperLogLogAggregationFunction() {} + Class accumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + Accumulator.class, + metadata, + classLoader); + Class groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + GroupedAccumulator.class, + metadata, + classLoader); + return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(intermediateType), K_HYPER_LOG_LOG, + true, false, metadata, accumulatorClass, groupedAccumulatorClass); + } + + private static MethodHandle getMethodHandle(Type firstInputType, Type secondInputType) + { + MethodHandle inputFunction; + if (firstInputType.getJavaType() == long.class && secondInputType.getJavaType() == long.class) { + inputFunction = LONG_LONG_INPUT_FUNCTION; + } + else if (firstInputType.getJavaType() == Slice.class && secondInputType.getJavaType() == long.class) { + inputFunction = SLICE_LONG_INPUT_FUNCTION; + } + else if (firstInputType.getJavaType() == double.class && secondInputType.getJavaType() == long.class) { + inputFunction = DOUBLE_LONG_INPUT_FUNCTION; + } + else if (firstInputType.getJavaType() == long.class && secondInputType.getJavaType() == Slice.class) { + inputFunction = LONG_SLICE_INPUT_FUNCTION; + } + else if (firstInputType.getJavaType() == Slice.class && secondInputType.getJavaType() == Slice.class) { + inputFunction = SLICE_SLICE_INPUT_FUNCTION; + } + else if (firstInputType.getJavaType() == double.class && secondInputType.getJavaType() == Slice.class) { + inputFunction = DOUBLE_SLICE_INPUT_FUNCTION; + } + else { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "input types for khyperloglog_agg are not supported"); + } + return inputFunction; + } + + @Override + public String getDescription() + { + return "Returns the KHyperLogLog sketch that represents the relationship between columns x and y. The MinHash structure summarizes x and the HyperLogLog sketches represent y values linked to x values."; + } - @InputFunction - public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.BIGINT) long uii) + public static void input(KHyperLogLogState state, long value, long uii) { if (state.getKHLL() == null) { state.setKHLL(new KHyperLogLog()); @@ -42,9 +153,7 @@ public static void input(@AggregationState KHyperLogLogState state, @SqlType(Sta state.getKHLL().add(value, uii); } - @InputFunction - @LiteralParameters("x") - public static void input(@AggregationState KHyperLogLogState state, @SqlType("varchar(x)") Slice value, @SqlType(StandardTypes.BIGINT) long uii) + public static void input(KHyperLogLogState state, Slice value, long uii) { if (state.getKHLL() == null) { state.setKHLL(new KHyperLogLog()); @@ -52,35 +161,27 @@ public static void input(@AggregationState KHyperLogLogState state, @SqlType("va state.getKHLL().add(value, uii); } - @InputFunction - public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType(StandardTypes.BIGINT) long uii) + public static void input(KHyperLogLogState state, double value, long uii) { input(state, Double.doubleToLongBits(value), uii); } - @InputFunction - @LiteralParameters("x") - public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.BIGINT) long value, @SqlType("varchar(x)") Slice uii) + public static void input(KHyperLogLogState state, long value, Slice uii) { input(state, value, XxHash64.hash(uii)); } - @InputFunction - @LiteralParameters({"x", "y"}) - public static void input(@AggregationState KHyperLogLogState state, @SqlType("varchar(x)") Slice value, @SqlType("varchar(y)") Slice uii) + public static void input(KHyperLogLogState state, Slice value, Slice uii) { input(state, value, XxHash64.hash(uii)); } - @InputFunction - @LiteralParameters("x") - public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType("varchar(x)") Slice uii) + public static void input(KHyperLogLogState state, double value, Slice uii) { input(state, Double.doubleToLongBits(value), XxHash64.hash(uii)); } - @CombineFunction - public static void combine(@AggregationState KHyperLogLogState state, @AggregationState KHyperLogLogState otherState) + public static void combine(KHyperLogLogState state, KHyperLogLogState otherState) { if (state.getKHLL() == null) { KHyperLogLog copy = new KHyperLogLog(); @@ -92,8 +193,7 @@ public static void combine(@AggregationState KHyperLogLogState state, @Aggregati } } - @OutputFunction(KHyperLogLogType.NAME) - public static void output(@AggregationState KHyperLogLogState state, BlockBuilder out) + public static void output(KHyperLogLogState state, BlockBuilder out) { SERIALIZER.serialize(state, out); } diff --git a/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogStateFactory.java b/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogStateFactory.java index 5a2b68119d97e..b14e98956510e 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/type/khyperloglog/KHyperLogLogStateFactory.java @@ -15,16 +15,27 @@ package com.facebook.presto.type.khyperloglog; import com.facebook.presto.common.array.ObjectBigArray; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.GroupedAccumulatorState; import org.openjdk.jol.info.ClassLayout; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static java.lang.String.format; + public class KHyperLogLogStateFactory implements AccumulatorStateFactory { private static final int SIZE_OF_SINGLE = ClassLayout.parseClass(SingleKHyperLogLogState.class).instanceSize(); private static final int SIZE_OF_GROUPED = ClassLayout.parseClass(GroupedKHyperLogLogState.class).instanceSize(); + private final long groupLimit; + + public KHyperLogLogStateFactory(long groupLimit) + { + this.groupLimit = groupLimit; + } + @Override public KHyperLogLogState createSingleState() { @@ -40,7 +51,7 @@ public Class getSingleStateClass() @Override public KHyperLogLogState createGroupedState() { - return new GroupedKHyperLogLogState(); + return new GroupedKHyperLogLogState(groupLimit); } @Override @@ -55,6 +66,12 @@ public static class GroupedKHyperLogLogState private final ObjectBigArray khlls = new ObjectBigArray<>(); private long groupId; private long size; + private final long groupLimit; + + public GroupedKHyperLogLogState(long groupLimit) + { + this.groupLimit = groupLimit; + } @Override public void setGroupId(long groupId) @@ -65,6 +82,9 @@ public void setGroupId(long groupId) @Override public void ensureCapacity(long size) { + if (groupLimit > 0 && size > groupLimit) { + throw new PrestoException(NOT_SUPPORTED, format("GroupedKHyperLogLogState number of groups exceed limit %d set by khyperloglog-agg-group-limit", groupLimit)); + } khlls.ensureCapacity(size); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index c2989699ae9ea..5d1c754221782 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -258,7 +258,8 @@ public void testDefaults() .setRemoveRedundantCastToVarcharInJoin(true) .setHandleComplexEquiJoins(false) .setSkipHashGenerationForJoinWithTableScanInput(false) - .setCteMaterializationStrategy(CteMaterializationStrategy.NONE)); + .setCteMaterializationStrategy(CteMaterializationStrategy.NONE) + .setKHyperLogLogAggregationGroupNumberLimit(0)); } @Test @@ -463,6 +464,7 @@ public void testExplicitPropertyMappings() .put("cte-materialization-strategy", "ALL") .put("optimizer.handle-complex-equi-joins", "true") .put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true") + .put("khyperloglog-agg-group-limit", "1000") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -664,7 +666,8 @@ public void testExplicitPropertyMappings() .setRemoveRedundantCastToVarcharInJoin(false) .setHandleComplexEquiJoins(true) .setSkipHashGenerationForJoinWithTableScanInput(true) - .setCteMaterializationStrategy(CteMaterializationStrategy.ALL); + .setCteMaterializationStrategy(CteMaterializationStrategy.ALL) + .setKHyperLogLogAggregationGroupNumberLimit(1000); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/type/khyperloglog/TestKHyperLogLogAggregationFunction.java b/presto-main/src/test/java/com/facebook/presto/type/khyperloglog/TestKHyperLogLogAggregationFunction.java index b355bde2f1de4..0e3dd93389ce5 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/khyperloglog/TestKHyperLogLogAggregationFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/type/khyperloglog/TestKHyperLogLogAggregationFunction.java @@ -18,7 +18,6 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -42,7 +41,7 @@ public class TestKHyperLogLogAggregationFunction { private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager(); - private static final String NAME = KHyperLogLogAggregationFunction.class.getAnnotation(AggregationFunction.class).value(); + private static final String NAME = KHyperLogLogAggregationFunction.getFunctionName(); @Test public void testSimpleKHyperLogLog()