diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java index f20d3374ebdc7..c61b9b69d46d5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ParametricAggregation.java @@ -85,7 +85,7 @@ public BuiltInAggregationFunctionImplementation specialize(BoundVariables variab // Build state factory and serializer Class stateClass = concreteImplementation.getStateClass(); AccumulatorStateSerializer stateSerializer = getAccumulatorStateSerializer(concreteImplementation, variables, functionAndTypeManager, stateClass, classLoader); - AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader); + AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateClass, variables.getTypeVariables(), classLoader); // Bind provided dependencies to aggregation method handlers MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), variables, functionAndTypeManager); @@ -188,7 +188,7 @@ private static AccumulatorStateSerializer getAccumulatorStateSerializer(Aggre } } else { - stateSerializer = generateStateSerializer(stateClass, classLoader); + stateSerializer = generateStateSerializer(stateClass, variables.getTypeVariables(), classLoader); } return stateSerializer; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java index 04e3889ecdb4a..cd0fb9088cce8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java @@ -38,6 +38,7 @@ import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; import com.google.common.collect.ImmutableList; @@ -48,10 +49,12 @@ import org.openjdk.jol.info.ClassLayout; import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -146,9 +149,9 @@ public static AccumulatorStateSerializer generateStateSerializer(Class AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); if (metadata != null && metadata.stateSerializerClass() != void.class) { try { - return (AccumulatorStateSerializer) metadata.stateSerializerClass().getConstructor().newInstance(); + return getAccumulatorStateMetadataInstance(metadata.stateSerializerClass(), fieldTypes); } - catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { throw new RuntimeException(e); } } @@ -178,6 +181,30 @@ public static AccumulatorStateSerializer generateStateSerializer(Class } } + private static T getAccumulatorStateMetadataInstance(Class clazz, Map fieldTypes) + throws InvocationTargetException, InstantiationException, IllegalAccessException + { + Optional> constructor = Arrays.stream(clazz.getConstructors()) + .filter(cons -> Modifier.isPublic(cons.getModifiers())) + .filter(cons -> Arrays.stream(cons.getParameters()).allMatch(param -> + Type.class.equals(param.getType()) && + // parameter must have valid @TypeParameter annotation + param.isAnnotationPresent(TypeParameter.class) && + fieldTypes.containsKey(param.getAnnotation(TypeParameter.class).value()))) + // this will only run with n > 1 values left + .reduce((first, second) -> { + throw new IllegalArgumentException("Multiple ambiguous annotated constructors in " + clazz + ". Only one valid constructor is allowed."); + }); + if (!constructor.isPresent()) { + throw new IllegalArgumentException("Unable to find a suitable constructor for accumulator metadata class " + clazz); + } + Constructor cons = constructor.get(); + Object[] params = Arrays.stream(cons.getParameters()) + .map(param -> fieldTypes.get(param.getAnnotation(TypeParameter.class).value())) + .toArray(); + return (T) cons.newInstance(params); + } + private static void generateGetSerializedType(ClassDefinition definition, List fields, CallSiteBinder callSiteBinder) { BytecodeBlock body = definition.declareMethod(a(PUBLIC), "getSerializedType", type(Type.class)).getBody(); @@ -351,9 +378,9 @@ public static AccumulatorStateFactory generateStateFactory(Class clazz AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); if (metadata != null && metadata.stateFactoryClass() != void.class) { try { - return (AccumulatorStateFactory) metadata.stateFactoryClass().getConstructor().newInstance(); + return getAccumulatorStateMetadataInstance(metadata.stateFactoryClass(), fieldTypes); } - catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { throw new RuntimeException(e); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java index dbea2c6f72aeb..2060226e61405 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java @@ -33,8 +33,10 @@ import com.facebook.presto.operator.aggregation.state.VarianceState; import com.facebook.presto.spi.function.AccumulatorState; import com.facebook.presto.spi.function.AccumulatorStateFactory; +import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.AccumulatorStateSerializer; import com.facebook.presto.spi.function.GroupedAccumulatorState; +import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.util.Reflection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -59,6 +61,8 @@ import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedDoubleArray; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; public class TestStateCompiler { @@ -361,6 +365,30 @@ public void testComplexStateEstimatedSize() } } + @Test + public void testStateSerializerConstructorsWithMetadata() + { + Map fields = ImmutableMap.of("T", BIGINT, "E", VARCHAR); + Object stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerNoType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerNoType.class.getClassLoader())); + assertTrue(stateSerializer instanceof TestAccumulatorSerializerNoType); + + stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerMultipleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerMultipleType.class.getClassLoader())); + assertTrue(stateSerializer instanceof TestAccumulatorSerializerMultipleType); + + stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerSingleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerSingleType.class.getClassLoader())); + assertTrue(stateSerializer instanceof TestAccumulatorSerializerSingleType); + + assertThrows(() -> StateCompiler.generateStateSerializer( + TestAccumulatorSerializerUntyped.class, + fields, + new DynamicClassLoader(TestAccumulatorSerializerUntyped.class.getClassLoader()))); + + assertThrows(() -> StateCompiler.generateStateSerializer( + TestAccumulatorAmbiguousConstructor.class, + fields, + new DynamicClassLoader(TestAccumulatorAmbiguousConstructor.class.getClassLoader()))); + } + public interface TestComplexState extends AccumulatorState { @@ -428,4 +456,77 @@ public interface SliceState void setSlice(Slice slice); } + + private abstract static class TestAccumulatorSerializer + implements AccumulatorStateSerializer + { + @Override + public Type getSerializedType() + { + return null; + } + + @Override + public void serialize(Object state, BlockBuilder out) + {} + + @Override + public void deserialize(Block block, int index, Object state) + {} + } + + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerSingleType.class) + public static class TestAccumulatorSerializerSingleType + extends TestAccumulatorSerializer + { + public TestAccumulatorSerializerSingleType(@TypeParameter("E") Type first) + {} + } + + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerMultipleType.class) + public static class TestAccumulatorSerializerMultipleType + extends TestAccumulatorSerializer + { + public TestAccumulatorSerializerMultipleType(@TypeParameter("E") Type first, @TypeParameter("T") Type second) + {} + } + + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerNoType.class) + public static class TestAccumulatorSerializerNoType + extends TestAccumulatorSerializer + { + public TestAccumulatorSerializerNoType() + {} + } + + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorAmbiguousConstructor.class) + public static class TestAccumulatorAmbiguousConstructor + extends TestAccumulatorSerializer + { + public TestAccumulatorAmbiguousConstructor() + {} + + public TestAccumulatorAmbiguousConstructor(@TypeParameter("E") Type type) + {} + } + + // test all invalid constructor types + @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerUntyped.class) + public static class TestAccumulatorSerializerUntyped + extends TestAccumulatorSerializer + { + public TestAccumulatorSerializerUntyped(Type x) + {} + + public TestAccumulatorSerializerUntyped(int y) + {} + + // type parameter G should not be in passed fields + public TestAccumulatorSerializerUntyped(@TypeParameter("G") Object y) + {} + + // type parameter G should not be in passed fields + public TestAccumulatorSerializerUntyped(@TypeParameter("E") Type x, @TypeParameter("G") Long y) + {} + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java index 17064e0b9545f..cd00af12e25d4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java @@ -202,7 +202,7 @@ private static void verifyExactOutputFunction(MethodHandle method, List[] parameterTypes = method.type().parameterArray(); - checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for combine function must be exactly one plus than number of states."); + checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for output function must be exactly one plus than number of states."); for (int i = 0; i < stateDescriptors.size(); i++) { checkArgument(parameterTypes[i].equals(stateDescriptors.get(i).getStateInterface()), format("Type for Parameter index %d is unexpected", i)); }