From 6f22d50fd6c8daecbc4159be282773513bb6b8d3 Mon Sep 17 00:00:00 2001 From: Zac Blanco Date: Mon, 6 Nov 2023 12:05:11 -0500 Subject: [PATCH] Allow Type in agg function metadata constructors This change makes some improvements to the annotation parsing for aggregation functions. One of the main downsides to the previous implementation is that parametric aggregations with defined type parameters couldn't be easily specialized, as their serialization and deserialization functions usually relied upon knowing the type parameters. Thus, many of the parametric aggregation functions in the Presto codebase don't use the @AggregationFunction and associated annotations, which makes the code more complex and hard to maintain. We introduce the ability for the AccumulatorStateMetadata classes, AccumulatorStateSerializer and AccumulatorStateFactory, to declare `Type` parameters in their constructors with proper @TypeParameter annotations. This should allow any new parametric aggregation functions to be implemented using annotations rather than manually extending BuiltInSqlFunction. --- .../aggregation/ParametricAggregation.java | 4 +- .../aggregation/state/StateCompiler.java | 35 +++++- .../aggregation/TestStateCompiler.java | 101 ++++++++++++++++++ .../aggregation/AggregationMetadata.java | 2 +- 4 files changed, 135 insertions(+), 7 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java index f20d3374ebdc7..c61b9b69d46d5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java @@ -85,7 +85,7 @@ public BuiltInAggregationFunctionImplementation specialize(BoundVariables variab // Build state factory and serializer Class stateClass = concreteImplementation.getStateClass(); 
AccumulatorStateSerializer stateSerializer = getAccumulatorStateSerializer(concreteImplementation, variables, functionAndTypeManager, stateClass, classLoader); - AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader); + AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateClass, variables.getTypeVariables(), classLoader); // Bind provided dependencies to aggregation method handlers MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), variables, functionAndTypeManager); @@ -188,7 +188,7 @@ private static AccumulatorStateSerializer getAccumulatorStateSerializer(Aggre } } else { - stateSerializer = generateStateSerializer(stateClass, classLoader); + stateSerializer = generateStateSerializer(stateClass, variables.getTypeVariables(), classLoader); } return stateSerializer; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java index 04e3889ecdb4a..cd0fb9088cce8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java @@ -38,6 +38,7 @@ import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; import com.google.common.collect.ImmutableList; @@ -48,10 +49,12 @@ import org.openjdk.jol.info.ClassLayout; import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; import 
java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -146,9 +149,9 @@ public static AccumulatorStateSerializer generateStateSerializer(Class AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); if (metadata != null && metadata.stateSerializerClass() != void.class) { try { - return (AccumulatorStateSerializer) metadata.stateSerializerClass().getConstructor().newInstance(); + return getAccumulatorStateMetadataInstance(metadata.stateSerializerClass(), fieldTypes); } - catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { throw new RuntimeException(e); } } @@ -178,6 +181,30 @@ public static AccumulatorStateSerializer generateStateSerializer(Class } } + /** Instantiates {@code clazz} through its single eligible public constructor. A constructor is eligible when every parameter is declared as {@code Type}, carries a {@code @TypeParameter} annotation, and that annotation's value is a key of {@code fieldTypes}. A no-argument constructor is also eligible, because {@code allMatch} over zero parameters is true. Constructor arguments are looked up in {@code fieldTypes} by annotation value, in parameter order. Throws {@code IllegalArgumentException} when no constructor is eligible or when more than one is; reflective failures from {@code newInstance} propagate as the declared checked exceptions. */ private static T getAccumulatorStateMetadataInstance(Class clazz, Map fieldTypes) + throws InvocationTargetException, InstantiationException, IllegalAccessException + { + Optional> constructor = Arrays.stream(clazz.getConstructors()) + /* NOTE(review): Class.getConstructors() is documented to return only public constructors, so this filter looks redundant - confirm before removing */ .filter(cons -> Modifier.isPublic(cons.getModifiers())) + .filter(cons -> Arrays.stream(cons.getParameters()).allMatch(param -> + Type.class.equals(param.getType()) && + // parameter must have valid @TypeParameter annotation + param.isAnnotationPresent(TypeParameter.class) && + fieldTypes.containsKey(param.getAnnotation(TypeParameter.class).value()))) + // this will only run with n > 1 values left + .reduce((first, second) -> { + throw new IllegalArgumentException("Multiple ambiguous annotated constructors in " + clazz + ". 
Only one valid constructor is allowed."); + }); + if (!constructor.isPresent()) { + throw new IllegalArgumentException("Unable to find a suitable constructor for accumulator metadata class " + clazz); + } + Constructor cons = constructor.get(); + Object[] params = Arrays.stream(cons.getParameters()) + .map(param -> fieldTypes.get(param.getAnnotation(TypeParameter.class).value())) + .toArray(); + return (T) cons.newInstance(params); + } + private static void generateGetSerializedType(ClassDefinition definition, List fields, CallSiteBinder callSiteBinder) { BytecodeBlock body = definition.declareMethod(a(PUBLIC), "getSerializedType", type(Type.class)).getBody(); @@ -351,9 +378,9 @@ public static AccumulatorStateFactory generateStateFactory(Class clazz AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); if (metadata != null && metadata.stateFactoryClass() != void.class) { try { - return (AccumulatorStateFactory) metadata.stateFactoryClass().getConstructor().newInstance(); + return getAccumulatorStateMetadataInstance(metadata.stateFactoryClass(), fieldTypes); } - catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { throw new RuntimeException(e); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java index dbea2c6f72aeb..2060226e61405 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java @@ -33,8 +33,10 @@ import com.facebook.presto.operator.aggregation.state.VarianceState; import com.facebook.presto.spi.function.AccumulatorState; import com.facebook.presto.spi.function.AccumulatorStateFactory; +import 
com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.AccumulatorStateSerializer; import com.facebook.presto.spi.function.GroupedAccumulatorState; +import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.util.Reflection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -59,6 +61,8 @@ import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedDoubleArray; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; public class TestStateCompiler { @@ -361,6 +365,30 @@ public void testComplexStateEstimatedSize() } } + /** Checks that {@code StateCompiler.generateStateSerializer} returns an instance of the serializer class named in {@code @AccumulatorStateMetadata} when that class exposes a no-argument constructor or a constructor taking one or two {@code @TypeParameter}-annotated {@code Type} arguments whose names ("T", "E") are present in the supplied map, and that it throws when no constructor qualifies or when two constructors both qualify. */ @Test + public void testStateSerializerConstructorsWithMetadata() + { + Map fields = ImmutableMap.of("T", BIGINT, "E", VARCHAR); + Object stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerNoType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerNoType.class.getClassLoader())); + assertTrue(stateSerializer instanceof TestAccumulatorSerializerNoType); + + stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerMultipleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerMultipleType.class.getClassLoader())); + assertTrue(stateSerializer instanceof TestAccumulatorSerializerMultipleType); + + stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerSingleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerSingleType.class.getClassLoader())); + assertTrue(stateSerializer instanceof TestAccumulatorSerializerSingleType); + + assertThrows(() -> StateCompiler.generateStateSerializer( + TestAccumulatorSerializerUntyped.class, + fields, + new DynamicClassLoader(TestAccumulatorSerializerUntyped.class.getClassLoader()))); + + assertThrows(() -> StateCompiler.generateStateSerializer( + 
TestAccumulatorAmbiguousConstructor.class, + fields, + new DynamicClassLoader(TestAccumulatorAmbiguousConstructor.class.getClassLoader()))); + } + public interface TestComplexState extends AccumulatorState { @@ -428,4 +456,77 @@ public interface SliceState void setSlice(Slice slice); } + + /** Serializer stub shared by the fixtures below: it returns a null serialized type and its serialize/deserialize bodies are empty, so the fixtures differ only in their constructors. */ private abstract static class TestAccumulatorSerializer + implements AccumulatorStateSerializer + { + @Override + public Type getSerializedType() + { + return null; + } + + @Override + public void serialize(Object state, BlockBuilder out) + {} + + @Override + public void deserialize(Block block, int index, Object state) + {} + } + + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerSingleType.class) + public static class TestAccumulatorSerializerSingleType + extends TestAccumulatorSerializer + { + public TestAccumulatorSerializerSingleType(@TypeParameter("E") Type first) + {} + } + + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerMultipleType.class) + public static class TestAccumulatorSerializerMultipleType + extends TestAccumulatorSerializer + { + public TestAccumulatorSerializerMultipleType(@TypeParameter("E") Type first, @TypeParameter("T") Type second) + {} + } + + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerNoType.class) + public static class TestAccumulatorSerializerNoType + extends TestAccumulatorSerializer + { + public TestAccumulatorSerializerNoType() + {} + } + + /** Fixture with two public constructors, a no-argument one and one taking a {@code @TypeParameter("E") Type}; the test expects serializer creation for this class to throw. */ @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorAmbiguousConstructor.class) + public static class TestAccumulatorAmbiguousConstructor + extends TestAccumulatorSerializer + { + public TestAccumulatorAmbiguousConstructor() + {} + + public TestAccumulatorAmbiguousConstructor(@TypeParameter("E") Type type) + {} + } + + // test all invalid constructor types + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerUntyped.class) + public static class TestAccumulatorSerializerUntyped + extends 
TestAccumulatorSerializer + { + /* Type parameter without a @TypeParameter annotation */ public TestAccumulatorSerializerUntyped(Type x) + {} + + /* parameter is neither a Type nor annotated with @TypeParameter */ public TestAccumulatorSerializerUntyped(int y) + {} + + // type parameter G should not be in passed fields + public TestAccumulatorSerializerUntyped(@TypeParameter("G") Object y) + {} + + // type parameter G should not be in passed fields + public TestAccumulatorSerializerUntyped(@TypeParameter("E") Type x, @TypeParameter("G") Long y) + {} + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java index 17064e0b9545f..cd00af12e25d4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java @@ -202,7 +202,7 @@ private static void verifyExactOutputFunction(MethodHandle method, List[] parameterTypes = method.type().parameterArray(); - checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for combine function must be exactly one plus than number of states."); + checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for output function must be exactly one plus than number of states."); for (int i = 0; i < stateDescriptors.size(); i++) { checkArgument(parameterTypes[i].equals(stateDescriptors.get(i).getStateInterface()), format("Type for Parameter index %d is unexpected", i)); }