Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ private List<? extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC
.scalars(SetDigestFunctions.class)
.scalars(SetDigestOperators.class)
.aggregates(MergeKHyperLogLogAggregationFunction.class)
.aggregates(KHyperLogLogAggregationFunction.class)
.function(new KHyperLogLogAggregationFunction(featuresConfig.getKHyperLogLogAggregationGroupNumberLimit()))
.scalars(KHyperLogLogFunctions.class)
.scalars(KHyperLogLogOperators.class)
.scalars(WilsonInterval.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ public class FeaturesConfig

private boolean removeRedundantCastToVarcharInJoin = true;
private boolean skipHashGenerationForJoinWithTableScanInput;
private long kHyperLogLogAggregationGroupNumberLimit;

public enum PartitioningPrecisionStrategy
{
Expand Down Expand Up @@ -2957,4 +2958,17 @@ public FeaturesConfig setSkipHashGenerationForJoinWithTableScanInput(boolean ski
this.skipHashGenerationForJoinWithTableScanInput = skipHashGenerationForJoinWithTableScanInput;
return this;
}

/**
 * Returns the group-number limit for khyperloglog_agg held by this config.
 * The backing field has no initializer, so the value is 0 until the setter is called.
 */
public long getKHyperLogLogAggregationGroupNumberLimit()
{
return kHyperLogLogAggregationGroupNumberLimit;
}

/**
 * Sets the maximum number of groups for khyperloglog_agg per task, bound to the
 * {@code khyperloglog-agg-group-limit} config property.
 * The grouped aggregation state only enforces the limit when it is greater than zero.
 *
 * @param kHyperLogLogAggregationGroupNumberLimit the limit to store
 * @return this config instance, so calls can be chained
 */
@Config("khyperloglog-agg-group-limit")
@ConfigDescription("Maximum number of groups for khyperloglog_agg per task")
public FeaturesConfig setKHyperLogLogAggregationGroupNumberLimit(long kHyperLogLogAggregationGroupNumberLimit)
{
this.kHyperLogLogAggregationGroupNumberLimit = kHyperLogLogAggregationGroupNumberLimit;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,76 +11,177 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.type.khyperloglog;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.LiteralParameters;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AccumulatorCompiler;
import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.aggregation.Accumulator;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.XxHash64;

@AggregationFunction("khyperloglog_agg")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain more about these changes? Is this some new style of writing aggregation functions?

Copy link
Copy Markdown
Contributor Author

@feilong-liu feilong-liu Dec 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The annotation-based style used before this change is an easier way to write an aggregation function; it corresponds to the ParametricAggregation class. The input, combine, and output functions, together with their annotations, are used to generate the aggregation function. However, that style is not flexible. Because we need to specify a limit on the number of groups, I chose to refactor the code to override the functions directly. Both ways are widely used in our codebase.

import java.lang.invoke.MethodHandle;
import java.util.List;

import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.function.Signature.typeVariable;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.type.khyperloglog.KHyperLogLogType.K_HYPER_LOG_LOG;
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.collect.ImmutableList.toImmutableList;

public final class KHyperLogLogAggregationFunction
extends SqlAggregationFunction
{
private static final String NAME = "khyperloglog_agg";
private static final KHyperLogLogStateSerializer SERIALIZER = new KHyperLogLogStateSerializer();
private static final MethodHandle LONG_LONG_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, long.class, long.class);
private static final MethodHandle SLICE_LONG_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, Slice.class, long.class);
private static final MethodHandle DOUBLE_LONG_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, double.class, long.class);
private static final MethodHandle LONG_SLICE_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, long.class, Slice.class);
private static final MethodHandle SLICE_SLICE_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, Slice.class, Slice.class);
private static final MethodHandle DOUBLE_SLICE_INPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "input", KHyperLogLogState.class, double.class, Slice.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "output", KHyperLogLogState.class, BlockBuilder.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(KHyperLogLogAggregationFunction.class, "combine", KHyperLogLogState.class, KHyperLogLogState.class);
private final long groupLimit;

/**
 * Creates the khyperloglog_agg function. The signature is generic over two type
 * variables, E and T, which are also the two argument types; the return type is
 * the KHyperLogLog type.
 *
 * @param groupLimit limit on the number of groups, forwarded to the state factory
 */
public KHyperLogLogAggregationFunction(long groupLimit)
{
    super(
            NAME,
            ImmutableList.of(typeVariable("E"), typeVariable("T")),
            ImmutableList.of(),
            K_HYPER_LOG_LOG.getTypeSignature(),
            ImmutableList.of(parseTypeSignature("E"), parseTypeSignature("T")));
    this.groupLimit = groupLimit;
}

/**
 * Returns the SQL name under which this aggregation is registered ("khyperloglog_agg").
 */
public static String getFunctionName()
{
return NAME;
}

/**
 * Binds the function to concrete argument types by resolving the type
 * variables E (first argument) and T (second argument) and building the
 * aggregation implementation for that pair.
 */
@Override
public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
{
    return generateAggregation(
            boundVariables.getTypeVariable("E"),
            boundVariables.getTypeVariable("T"));
}

/**
 * Builds the aggregation implementation for one concrete pair of input types.
 * Picks the matching input method handle, assembles the aggregation metadata
 * (including a state factory that carries {@code groupLimit}), and compiles the
 * plain and grouped accumulator classes from that metadata.
 *
 * @param firstInputType type bound to the first argument
 * @param secondInputType type bound to the second argument
 * @return the implementation for the given input types
 */
private BuiltInAggregationFunctionImplementation generateAggregation(Type firstInputType, Type secondInputType)
{
DynamicClassLoader classLoader = new DynamicClassLoader(KHyperLogLogAggregationFunction.class.getClassLoader());
List<Type> inputTypes = ImmutableList.of(firstInputType, secondInputType);
Class<? extends AccumulatorState> stateInterface = KHyperLogLogState.class;
AccumulatorStateSerializer<?> stateSerializer = new KHyperLogLogStateSerializer();
// Throws INVALID_FUNCTION_ARGUMENT when the type pair has no matching input overload.
MethodHandle inputFunction = getMethodHandle(firstInputType, secondInputType);

// Parameter layout: the state first, then one input channel per argument.
AggregationMetadata metadata = new AggregationMetadata(
generateAggregationName(NAME, K_HYPER_LOG_LOG.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
ImmutableList.of(new AggregationMetadata.ParameterMetadata(STATE), new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, firstInputType), new AggregationMetadata.ParameterMetadata(INPUT_CHANNEL, secondInputType)),
inputFunction,
COMBINE_FUNCTION,
OUTPUT_FUNCTION,
ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(
stateInterface,
stateSerializer,
new KHyperLogLogStateFactory(groupLimit))),
K_HYPER_LOG_LOG);

// The intermediate type is whatever the state serializer writes.
Type intermediateType = stateSerializer.getSerializedType();

// NOTE(review): the next line looks like a removed line carried over from the diff view — confirm it is not part of this method.
private KHyperLogLogAggregationFunction() {}
// Both accumulator variants are generated from the same metadata and class loader.
Class<? extends Accumulator> accumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
Accumulator.class,
metadata,
classLoader);
Class<? extends GroupedAccumulator> groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
GroupedAccumulator.class,
metadata,
classLoader);
return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(intermediateType), K_HYPER_LOG_LOG,
true, false, metadata, accumulatorClass, groupedAccumulatorClass);
}

/**
 * Chooses the {@code input} overload whose Java parameter types match the two
 * SQL argument types. The first argument may map to long, Slice or double; the
 * second to long or Slice.
 *
 * @param firstInputType type of the first argument
 * @param secondInputType type of the second argument
 * @return the method handle of the matching input overload
 * @throws PrestoException with INVALID_FUNCTION_ARGUMENT if no overload matches
 */
private static MethodHandle getMethodHandle(Type firstInputType, Type secondInputType)
{
    Class<?> valueClass = firstInputType.getJavaType();
    Class<?> uiiClass = secondInputType.getJavaType();

    if (uiiClass == long.class) {
        if (valueClass == long.class) {
            return LONG_LONG_INPUT_FUNCTION;
        }
        if (valueClass == Slice.class) {
            return SLICE_LONG_INPUT_FUNCTION;
        }
        if (valueClass == double.class) {
            return DOUBLE_LONG_INPUT_FUNCTION;
        }
    }
    else if (uiiClass == Slice.class) {
        if (valueClass == long.class) {
            return LONG_SLICE_INPUT_FUNCTION;
        }
        if (valueClass == Slice.class) {
            return SLICE_SLICE_INPUT_FUNCTION;
        }
        if (valueClass == double.class) {
            return DOUBLE_SLICE_INPUT_FUNCTION;
        }
    }
    throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "input types for khyperloglog_agg are not supported");
}

/**
 * Returns the human-readable description of khyperloglog_agg.
 */
@Override
public String getDescription()
{
return "Returns the KHyperLogLog sketch that represents the relationship between columns x and y. The MinHash structure summarizes x and the HyperLogLog sketches represent y values linked to x values.";
}

/**
 * Adds the (value, uii) pair to the state's KHyperLogLog sketch, creating an
 * empty sketch first when the state does not hold one yet.
 */
@InputFunction
public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.BIGINT) long uii)
public static void input(KHyperLogLogState state, long value, long uii)
{
if (state.getKHLL() == null) {
state.setKHLL(new KHyperLogLog());
}
state.getKHLL().add(value, uii);
}

/**
 * Adds the (value, uii) pair, with a Slice value, to the state's KHyperLogLog
 * sketch, creating an empty sketch first when the state does not hold one yet.
 */
@InputFunction
@LiteralParameters("x")
public static void input(@AggregationState KHyperLogLogState state, @SqlType("varchar(x)") Slice value, @SqlType(StandardTypes.BIGINT) long uii)
public static void input(KHyperLogLogState state, Slice value, long uii)
{
if (state.getKHLL() == null) {
state.setKHLL(new KHyperLogLog());
}
state.getKHLL().add(value, uii);
}

/**
 * Handles a double value by converting it to its long bit pattern and
 * delegating to the (long, long) overload.
 */
@InputFunction
public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType(StandardTypes.BIGINT) long uii)
public static void input(KHyperLogLogState state, double value, long uii)
{
input(state, Double.doubleToLongBits(value), uii);
}

/**
 * Handles a Slice uii by hashing it with XxHash64 and delegating to the
 * (long, long) overload.
 */
@InputFunction
@LiteralParameters("x")
public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.BIGINT) long value, @SqlType("varchar(x)") Slice uii)
public static void input(KHyperLogLogState state, long value, Slice uii)
{
input(state, value, XxHash64.hash(uii));
}

/**
 * Handles a Slice value with a Slice uii: the uii is hashed with XxHash64 and
 * the call is delegated to the (Slice, long) overload.
 */
@InputFunction
@LiteralParameters({"x", "y"})
public static void input(@AggregationState KHyperLogLogState state, @SqlType("varchar(x)") Slice value, @SqlType("varchar(y)") Slice uii)
public static void input(KHyperLogLogState state, Slice value, Slice uii)
{
input(state, value, XxHash64.hash(uii));
}

/**
 * Handles a double value with a Slice uii: the value is converted to its long
 * bit pattern, the uii is hashed with XxHash64, and the call is delegated to
 * the (long, long) overload.
 */
@InputFunction
@LiteralParameters("x")
public static void input(@AggregationState KHyperLogLogState state, @SqlType(StandardTypes.DOUBLE) double value, @SqlType("varchar(x)") Slice uii)
public static void input(KHyperLogLogState state, double value, Slice uii)
{
input(state, Double.doubleToLongBits(value), XxHash64.hash(uii));
}

@CombineFunction
public static void combine(@AggregationState KHyperLogLogState state, @AggregationState KHyperLogLogState otherState)
public static void combine(KHyperLogLogState state, KHyperLogLogState otherState)
{
if (state.getKHLL() == null) {
KHyperLogLog copy = new KHyperLogLog();
Expand All @@ -92,8 +193,7 @@ public static void combine(@AggregationState KHyperLogLogState state, @Aggregati
}
}

/**
 * Writes the aggregation result by serializing the state into {@code out}
 * with the shared state serializer.
 */
@OutputFunction(KHyperLogLogType.NAME)
public static void output(@AggregationState KHyperLogLogState state, BlockBuilder out)
public static void output(KHyperLogLogState state, BlockBuilder out)
{
SERIALIZER.serialize(state, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,27 @@
package com.facebook.presto.type.khyperloglog;

import com.facebook.presto.common.array.ObjectBigArray;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.GroupedAccumulatorState;
import org.openjdk.jol.info.ClassLayout;

import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static java.lang.String.format;

public class KHyperLogLogStateFactory
implements AccumulatorStateFactory<KHyperLogLogState>
{
private static final int SIZE_OF_SINGLE = ClassLayout.parseClass(SingleKHyperLogLogState.class).instanceSize();
private static final int SIZE_OF_GROUPED = ClassLayout.parseClass(GroupedKHyperLogLogState.class).instanceSize();

private final long groupLimit;

/**
 * Creates a factory whose grouped states are given {@code groupLimit}.
 *
 * @param groupLimit limit on the number of groups, passed to each grouped state
 */
public KHyperLogLogStateFactory(long groupLimit)
{
this.groupLimit = groupLimit;
}

@Override
public KHyperLogLogState createSingleState()
{
Expand All @@ -40,7 +51,7 @@ public Class<? extends KHyperLogLogState> getSingleStateClass()
/**
 * Creates a new grouped state; the updated form passes this factory's group limit to it.
 */
@Override
public KHyperLogLogState createGroupedState()
{
return new GroupedKHyperLogLogState();
return new GroupedKHyperLogLogState(groupLimit);
}

@Override
Expand All @@ -55,6 +66,12 @@ public static class GroupedKHyperLogLogState
private final ObjectBigArray<KHyperLogLog> khlls = new ObjectBigArray<>();
private long groupId;
private long size;
private final long groupLimit;

/**
 * Creates a grouped state that checks requested capacity against {@code groupLimit}.
 *
 * @param groupLimit limit on the number of groups; only enforced when greater than zero
 */
public GroupedKHyperLogLogState(long groupLimit)
{
this.groupLimit = groupLimit;
}

@Override
public void setGroupId(long groupId)
Expand All @@ -65,6 +82,9 @@ public void setGroupId(long groupId)
/**
 * Grows the per-group sketch array to hold {@code size} entries, rejecting the
 * request first when a positive group limit is configured and {@code size}
 * exceeds it.
 *
 * @param size requested capacity
 * @throws PrestoException with NOT_SUPPORTED when the limit is exceeded
 */
@Override
public void ensureCapacity(long size)
{
    // A limit of zero or less means no limit is enforced.
    boolean limitEnabled = groupLimit > 0;
    if (limitEnabled && size > groupLimit) {
        String message = format("GroupedKHyperLogLogState number of groups exceed limit %d set by khyperloglog-agg-group-limit", groupLimit);
        throw new PrestoException(NOT_SUPPORTED, message);
    }
    khlls.ensureCapacity(size);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ public void testDefaults()
.setRemoveRedundantCastToVarcharInJoin(true)
.setHandleComplexEquiJoins(false)
.setSkipHashGenerationForJoinWithTableScanInput(false)
.setCteMaterializationStrategy(CteMaterializationStrategy.NONE));
.setCteMaterializationStrategy(CteMaterializationStrategy.NONE)
.setKHyperLogLogAggregationGroupNumberLimit(0));
}

@Test
Expand Down Expand Up @@ -463,6 +464,7 @@ public void testExplicitPropertyMappings()
.put("cte-materialization-strategy", "ALL")
.put("optimizer.handle-complex-equi-joins", "true")
.put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true")
.put("khyperloglog-agg-group-limit", "1000")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -664,7 +666,8 @@ public void testExplicitPropertyMappings()
.setRemoveRedundantCastToVarcharInJoin(false)
.setHandleComplexEquiJoins(true)
.setSkipHashGenerationForJoinWithTableScanInput(true)
.setCteMaterializationStrategy(CteMaterializationStrategy.ALL);
.setCteMaterializationStrategy(CteMaterializationStrategy.ALL)
.setKHyperLogLogAggregationGroupNumberLimit(1000);
assertFullMapping(properties, expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
Expand All @@ -42,7 +41,7 @@
public class TestKHyperLogLogAggregationFunction
{
private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
private static final String NAME = KHyperLogLogAggregationFunction.class.getAnnotation(AggregationFunction.class).value();
private static final String NAME = KHyperLogLogAggregationFunction.getFunctionName();

@Test
public void testSimpleKHyperLogLog()
Expand Down