Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,35 @@
import io.trino.metadata.Signature;
import io.trino.operator.ParametricImplementationsGroup;
import io.trino.operator.annotations.FunctionsParserHelper;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.FunctionDependency;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameter;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.RemoveInputFunction;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.TypeSignature;

import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation;
import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;

public final class AggregationFromAnnotationsParser
Expand Down Expand Up @@ -174,38 +182,70 @@ private static List<String> getAliases(AggregationFunction aggregationAnnotation

private static Optional<Method> getCombineFunction(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class).stream()
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 0)] == stateClass)
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 1)] == stateClass)
.collect(toImmutableList());

List<Method> combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class);
for (Method combineFunction : combineFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(combineFunction);
List<Class<?>> expectedParameterTypes = nCopies(2, stateClass);
checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction);
}
checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateClass.toGenericString());
return combineFunctions.stream().findFirst();
}

private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class).stream()
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass)
.collect(toImmutableList());

List<Method> outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class);
for (Method outputFunction : outputFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(outputFunction);
List<Class<?>> expectedParameterTypes = ImmutableList.<Class<?>>builder()
.add(stateClass)
.add(BlockBuilder.class)
.build();
checkArgument(parameterTypes.equals(expectedParameterTypes),
"Expected output function non-dependency parameters to be %s: %s",
expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()),
outputFunction);
}
checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions");
return outputFunctions;
}

private static List<Method> getInputFunctions(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class).stream()
.filter(method -> (method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass))
.collect(toImmutableList());
List<Method> inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class);
for (Method inputFunction : inputFunctions) {
// verify state parameter is first non-dependency parameter
Class<?> actualStateType = getNonDependencyParameterTypes(inputFunction).get(0);
checkArgument(stateClass.equals(actualStateType),
"Expected input function non-dependency parameters to begin with state type %s: %s",
stateClass.getSimpleName(),
inputFunction);
}

checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions");
return inputFunctions;
}

private static IntStream getNonDependencyParameters(Method function)
{
Annotation[][] parameterAnnotations = function.getParameterAnnotations();
return IntStream.range(0, function.getParameterCount())
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(TypeParameter.class::isInstance))
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(LiteralParameter.class::isInstance))
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(OperatorDependency.class::isInstance))
.filter(i -> Arrays.stream(parameterAnnotations[i]).noneMatch(FunctionDependency.class::isInstance));
}

private static List<Class<?>> getNonDependencyParameterTypes(Method function)
{
Class<?>[] parameterTypes = function.getParameterTypes();
return getNonDependencyParameters(function)
.mapToObj(index -> parameterTypes[index])
.collect(toImmutableList());
}

private static Optional<Method> getRemoveInputFunction(Class<?> clazz, Method inputFunction)
{
// Only include methods which take the same parameters as the corresponding input function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static java.lang.invoke.MethodType.methodType;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
Expand Down Expand Up @@ -146,9 +147,9 @@ public void testSimpleExactAggregationParse()
aggregation.specialize(boundSignature, NO_FUNCTION_DEPENDENCIES);
}

@AggregationFunction("simple_exact_aggregate_aggregation_state_moved")
@Description("Simple exact function which has @AggregationState on different than first positions")
public static final class StateOnDifferentThanFirstPositionAggregationFunction
@AggregationFunction("input_parameters_wrong_order")
@Description("AggregationState must be the first input parameter")
public static final class InputParametersWrongOrder
{
@InputFunction
public static void input(@SqlType(DOUBLE) double value, @AggregationState NullableDoubleState state)
Expand All @@ -163,27 +164,49 @@ public static void combine(@AggregationState NullableDoubleState combine1, @Aggr
}

@OutputFunction(DOUBLE)
public static void output(BlockBuilder out, @AggregationState NullableDoubleState state)
public static void output(@AggregationState NullableDoubleState state, BlockBuilder out)
{
// noop this is only for annotation testing puproses
}
}

@Test
public void testStateOnDifferentThanFirstPositionAggregationParse()
public void testInputParameterOrderEnforced()
{
Signature expectedSignature = Signature.builder()
.name("simple_exact_aggregate_aggregation_state_moved")
.returnType(DoubleType.DOUBLE)
.argumentType(DoubleType.DOUBLE)
.build();
assertThatThrownBy(() -> parseFunctionDefinitions(InputParametersWrongOrder.class))
.hasMessage("Expected input function non-dependency parameters to begin with state type NullableDoubleState: " +
"public static void io.trino.operator.TestAnnotationEngineForAggregates$InputParametersWrongOrder.input(double,io.trino.operator.aggregation.state.NullableDoubleState)");
}

ParametricAggregation aggregation = getOnlyElement(parseFunctionDefinitions(StateOnDifferentThanFirstPositionAggregationFunction.class));
assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature);
@AggregationFunction("output_parameters_wrong_order")
@Description("AggregationState must be the first output parameter")
public static final class OutputParametersWrongOrder
{
@InputFunction
public static void input(@AggregationState NullableDoubleState state, @SqlType(DOUBLE) double value)
{
// noop this is only for annotation testing puproses
}

AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
assertEquals(implementation.getDefinitionClass(), StateOnDifferentThanFirstPositionAggregationFunction.class);
assertEquals(implementation.getInputParameterKinds(), ImmutableList.of(INPUT_CHANNEL, STATE));
@CombineFunction
public static void combine(@AggregationState NullableDoubleState combine1, @AggregationState NullableDoubleState combine2)
{
// noop this is only for annotation testing puproses
}

@OutputFunction(DOUBLE)
public static void output(BlockBuilder out, @AggregationState NullableDoubleState state)
{
// noop this is only for annotation testing puproses
}
}

@Test
public void testOutputParameterOrderEnforced()
{
assertThatThrownBy(() -> parseFunctionDefinitions(OutputParametersWrongOrder.class))
.hasMessage("Expected output function non-dependency parameters to be [NullableDoubleState, BlockBuilder]: " +
"public static void io.trino.operator.TestAnnotationEngineForAggregates$OutputParametersWrongOrder.output(io.trino.spi.block.BlockBuilder,io.trino.operator.aggregation.state.NullableDoubleState)");
}

@AggregationFunction("no_aggregation_state_aggregate")
Expand Down