diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 70f4e6074d04..389a00f32752 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -20,20 +20,27 @@ import io.trino.metadata.Signature; import io.trino.operator.ParametricImplementationsGroup; import io.trino.operator.annotations.FunctionsParserHelper; +import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.FunctionDependency; import io.trino.spi.function.InputFunction; +import io.trino.spi.function.LiteralParameter; +import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.OutputFunction; import io.trino.spi.function.RemoveInputFunction; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.TypeSignature; +import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.emptyToNull; @@ -41,6 +48,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation; import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription; +import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; public final class AggregationFromAnnotationsParser @@ -174,38 +182,70 @@ private static List getAliases(AggregationFunction aggregationAnnotation private static Optional getCombineFunction(Class clazz, Class stateClass) { - // Only include methods that match this state class - List combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class).stream() - .filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 0)] == stateClass) - .filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 1)] == stateClass) - .collect(toImmutableList()); - + List combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class); + for (Method combineFunction : combineFunctions) { + // verify parameter types + List> parameterTypes = getNonDependencyParameterTypes(combineFunction); + List> expectedParameterTypes = nCopies(2, stateClass); + checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction); + } checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateClass.toGenericString()); return combineFunctions.stream().findFirst(); } private static List getOutputFunctions(Class clazz, Class stateClass) { - // Only include methods that match this state class - List outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class).stream() - .filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass) - .collect(toImmutableList()); - + List outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class); + for (Method outputFunction : outputFunctions) { + // verify parameter types + List> parameterTypes = getNonDependencyParameterTypes(outputFunction); + List> expectedParameterTypes = ImmutableList.>builder() + .add(stateClass) + .add(BlockBuilder.class) + .build(); + checkArgument(parameterTypes.equals(expectedParameterTypes), + "Expected output function non-dependency parameters to be %s: %s", + expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()), + outputFunction); + } checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions"); return outputFunctions; } private static List getInputFunctions(Class clazz, Class stateClass) { - // Only include methods that match this state class - List inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class).stream() - .filter(method -> (method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass)) - .collect(toImmutableList()); + List inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class); + for (Method inputFunction : inputFunctions) { + // verify state parameter is first non-dependency parameter + Class actualStateType = getNonDependencyParameterTypes(inputFunction).get(0); + checkArgument(stateClass.equals(actualStateType), + "Expected input function non-dependency parameters to begin with state type %s: %s", + stateClass.getSimpleName(), + inputFunction); + } checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions"); return inputFunctions; } + private static IntStream getNonDependencyParameters(Method function) + { + Annotation[][] parameterAnnotations = function.getParameterAnnotations(); + return IntStream.range(0, function.getParameterCount()) + .filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(TypeParameter.class::isInstance)) + .filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(LiteralParameter.class::isInstance)) + .filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(OperatorDependency.class::isInstance)) + .filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(FunctionDependency.class::isInstance)); + } + + private static List> getNonDependencyParameterTypes(Method function) + { + Class[] parameterTypes = function.getParameterTypes(); + return getNonDependencyParameters(function) + .mapToObj(index -> parameterTypes[index]) + .collect(toImmutableList()); + } + private static Optional getRemoveInputFunction(Class clazz, Method inputFunction) { // Only include methods which take the same parameters as the corresponding input function diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java index 4818a59ccc2e..333f8f2b2e3f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java @@ -84,6 +84,7 @@ import static io.trino.spi.type.TypeSignatureParameter.typeVariable; import static io.trino.spi.type.VarcharType.createVarcharType; import static java.lang.invoke.MethodType.methodType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -146,9 +147,9 @@ public void testSimpleExactAggregationParse() aggregation.specialize(boundSignature, NO_FUNCTION_DEPENDENCIES); } - @AggregationFunction("simple_exact_aggregate_aggregation_state_moved") - @Description("Simple exact function which has @AggregationState on different than first positions") - public static final class StateOnDifferentThanFirstPositionAggregationFunction + @AggregationFunction("input_parameters_wrong_order") + @Description("AggregationState must be the first input parameter") + public static final class InputParametersWrongOrder { @InputFunction public static void input(@SqlType(DOUBLE) double value, @AggregationState NullableDoubleState state) @@ -163,27 +164,49 @@ public static void combine(@AggregationState NullableDoubleState combine1, @Aggr } @OutputFunction(DOUBLE) - public static void output(BlockBuilder out, @AggregationState NullableDoubleState state) + public static void output(@AggregationState NullableDoubleState state, BlockBuilder out) { // noop this is only for annotation testing puproses } } @Test - public void testStateOnDifferentThanFirstPositionAggregationParse() + public void testInputParameterOrderEnforced() { - Signature expectedSignature = Signature.builder() - .name("simple_exact_aggregate_aggregation_state_moved") - .returnType(DoubleType.DOUBLE) - .argumentType(DoubleType.DOUBLE) - .build(); + assertThatThrownBy(() -> parseFunctionDefinitions(InputParametersWrongOrder.class)) + .hasMessage("Expected input function non-dependency parameters to begin with state type NullableDoubleState: " + + "public static void io.trino.operator.TestAnnotationEngineForAggregates$InputParametersWrongOrder.input(double,io.trino.operator.aggregation.state.NullableDoubleState)"); + } - ParametricAggregation aggregation = getOnlyElement(parseFunctionDefinitions(StateOnDifferentThanFirstPositionAggregationFunction.class)); - assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); + @AggregationFunction("output_parameters_wrong_order") + @Description("AggregationState must be the first output parameter") + public static final class OutputParametersWrongOrder + { + @InputFunction + public static void input(@AggregationState NullableDoubleState state, @SqlType(DOUBLE) double value) + { + // noop this is only for annotation testing puproses + } - AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values()); - assertEquals(implementation.getDefinitionClass(), StateOnDifferentThanFirstPositionAggregationFunction.class); - assertEquals(implementation.getInputParameterKinds(), ImmutableList.of(INPUT_CHANNEL, STATE)); + @CombineFunction + public static void combine(@AggregationState NullableDoubleState combine1, @AggregationState NullableDoubleState combine2) + { + // noop this is only for annotation testing puproses + } + + @OutputFunction(DOUBLE) + public static void output(BlockBuilder out, @AggregationState NullableDoubleState state) + { + // noop this is only for annotation testing puproses + } + } + + @Test + public void testOutputParameterOrderEnforced() + { + assertThatThrownBy(() -> parseFunctionDefinitions(OutputParametersWrongOrder.class)) + .hasMessage("Expected output function non-dependency parameters to be [NullableDoubleState, BlockBuilder]: " + + "public static void io.trino.operator.TestAnnotationEngineForAggregates$OutputParametersWrongOrder.output(io.trino.spi.block.BlockBuilder,io.trino.operator.aggregation.state.NullableDoubleState)"); } @AggregationFunction("no_aggregation_state_aggregate")