diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java index 4a7500044965..8c7cb11f60c6 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java @@ -22,6 +22,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.InOut; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.type.Type; @@ -196,7 +197,10 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, F break; case BLOCK_POSITION: verifyFunctionSignature(parameterType.equals(Block.class) && methodType.parameterType(parameterIndex + 1).equals(int.class), - "Expected BLOCK_POSITION argument have parameters Block and int"); + "Expected BLOCK_POSITION argument types to be Block and int"); + break; + case IN_OUT: + verifyFunctionSignature(parameterType.equals(InOut.class), "Expected IN_OUT argument type to be InOut"); break; case FUNCTION: Class lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex); diff --git a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java index f44b5433441b..e18e9d321210 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java @@ -129,6 +129,11 @@ private static boolean matchesParameterAndReturnTypes( } actualType = resolvedType.getJavaType(); break; + case IN_OUT: + // any type is supported, so just ignore this check + actualType = resolvedType.getJavaType(); + expectedType = resolvedType.getJavaType(); + break; default: throw new UnsupportedOperationException("Unknown argument convention: " + argumentConvention); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java b/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java index 4e8e51a5e105..ab8deb73b9b2 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java @@ -29,7 +29,12 @@ public abstract class SqlAggregationFunction public static List createFunctionsByAnnotations(Class aggregationDefinition) { - return ImmutableList.copyOf(AggregationFromAnnotationsParser.parseFunctionDefinitions(aggregationDefinition)); + try { + return ImmutableList.copyOf(AggregationFromAnnotationsParser.parseFunctionDefinitions(aggregationDefinition)); + } + catch (RuntimeException e) { + throw new IllegalArgumentException("Invalid aggregation class " + aggregationDefinition.getSimpleName()); + } } public SqlAggregationFunction(FunctionMetadata functionMetadata, AggregationFunctionMetadata aggregationFunctionMetadata) diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java index d3e443bfafd0..bcec65ba4aee 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java @@ -25,6 +25,7 @@ import io.trino.operator.aggregation.ApproximateRealPercentileArrayAggregations; import io.trino.operator.aggregation.ApproximateSetAggregation; import io.trino.operator.aggregation.ApproximateSetGenericAggregation; +import io.trino.operator.aggregation.ArbitraryAggregationFunction; import io.trino.operator.aggregation.AverageAggregations; import io.trino.operator.aggregation.BigintApproximateMostFrequent; import io.trino.operator.aggregation.BitwiseAndAggregation; @@ -36,7 +37,10 @@ import io.trino.operator.aggregation.CentralMomentsAggregation; import io.trino.operator.aggregation.ChecksumAggregationFunction; import io.trino.operator.aggregation.CountAggregation; +import io.trino.operator.aggregation.CountColumn; import io.trino.operator.aggregation.CountIfAggregation; +import io.trino.operator.aggregation.DecimalAverageAggregation; +import io.trino.operator.aggregation.DecimalSumAggregation; import io.trino.operator.aggregation.DefaultApproximateCountDistinctAggregation; import io.trino.operator.aggregation.DoubleCorrelationAggregation; import io.trino.operator.aggregation.DoubleCovarianceAggregation; @@ -54,12 +58,18 @@ import io.trino.operator.aggregation.LongSumAggregation; import io.trino.operator.aggregation.MapAggregationFunction; import io.trino.operator.aggregation.MapUnionAggregation; +import io.trino.operator.aggregation.MaxAggregationFunction; +import io.trino.operator.aggregation.MaxByAggregationFunction; import io.trino.operator.aggregation.MaxDataSizeForStats; -import io.trino.operator.aggregation.MaxNAggregationFunction; import io.trino.operator.aggregation.MergeHyperLogLogAggregation; import io.trino.operator.aggregation.MergeQuantileDigestFunction; import io.trino.operator.aggregation.MergeTDigestAggregation; -import io.trino.operator.aggregation.MinNAggregationFunction; +import io.trino.operator.aggregation.MinAggregationFunction; +import io.trino.operator.aggregation.MinByAggregationFunction; +import io.trino.operator.aggregation.QuantileDigestAggregationFunction.BigintQuantileDigestAggregationFunction; +import io.trino.operator.aggregation.QuantileDigestAggregationFunction.DoubleQuantileDigestAggregationFunction; +import io.trino.operator.aggregation.QuantileDigestAggregationFunction.RealQuantileDigestAggregationFunction; +import io.trino.operator.aggregation.RealAverageAggregation; import io.trino.operator.aggregation.RealCorrelationAggregation; import io.trino.operator.aggregation.RealCovarianceAggregation; import io.trino.operator.aggregation.RealGeometricMeanAggregations; @@ -70,9 +80,13 @@ import io.trino.operator.aggregation.TDigestAggregationFunction; import io.trino.operator.aggregation.VarcharApproximateMostFrequent; import io.trino.operator.aggregation.VarianceAggregation; +import io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction; import io.trino.operator.aggregation.histogram.Histogram; -import io.trino.operator.aggregation.minmaxby.MaxByNAggregationFunction; -import io.trino.operator.aggregation.minmaxby.MinByNAggregationFunction; +import io.trino.operator.aggregation.listagg.ListaggAggregationFunction; +import io.trino.operator.aggregation.minmaxbyn.MaxByNAggregationFunction; +import io.trino.operator.aggregation.minmaxbyn.MinByNAggregationFunction; +import io.trino.operator.aggregation.minmaxn.MaxNAggregationFunction; +import io.trino.operator.aggregation.minmaxn.MinNAggregationFunction; import io.trino.operator.aggregation.multimapagg.MultimapAggregationFunction; import io.trino.operator.scalar.ArrayAllMatchFunction; import io.trino.operator.scalar.ArrayAnyMatchFunction; @@ -255,21 +269,7 @@ import io.trino.type.setdigest.SetDigestFunctions; import io.trino.type.setdigest.SetDigestOperators; -import static io.trino.operator.aggregation.ArbitraryAggregationFunction.ARBITRARY_AGGREGATION; -import static io.trino.operator.aggregation.CountColumn.COUNT_COLUMN; -import static io.trino.operator.aggregation.DecimalAverageAggregation.DECIMAL_AVERAGE_AGGREGATION; -import static io.trino.operator.aggregation.DecimalSumAggregation.DECIMAL_SUM_AGGREGATION; -import static io.trino.operator.aggregation.MaxAggregationFunction.MAX_AGGREGATION; -import static io.trino.operator.aggregation.MinAggregationFunction.MIN_AGGREGATION; -import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG; -import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG_WITH_WEIGHT; -import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG_WITH_WEIGHT_AND_ERROR; -import static io.trino.operator.aggregation.RealAverageAggregation.REAL_AVERAGE_AGGREGATION; import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG; -import static io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction.ARRAY_AGG; -import static io.trino.operator.aggregation.listagg.ListaggAggregationFunction.LISTAGG; -import static io.trino.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY; -import static io.trino.operator.aggregation.minmaxby.MinByAggregationFunction.MIN_BY; import static io.trino.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION; import static io.trino.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR; import static io.trino.operator.scalar.ArrayFlattenFunction.ARRAY_FLATTEN_FUNCTION; @@ -391,7 +391,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .aggregates(IntervalDayToSecondSumAggregation.class) .aggregates(IntervalYearToMonthSumAggregation.class) .aggregates(AverageAggregations.class) - .function(REAL_AVERAGE_AGGREGATION) + .aggregates(RealAverageAggregation.class) .aggregates(IntervalDayToSecondAverageAggregation.class) .aggregates(IntervalYearToMonthAverageAggregation.class) .aggregates(GeometricMeanAggregations.class) @@ -400,8 +400,10 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .aggregates(ApproximateSetAggregation.class) .aggregates(ApproximateSetGenericAggregation.class) .aggregates(TDigestAggregationFunction.class) - .functions(QDIGEST_AGG, QDIGEST_AGG_WITH_WEIGHT, QDIGEST_AGG_WITH_WEIGHT_AND_ERROR) - .function(MergeQuantileDigestFunction.MERGE) + .aggregates(DoubleQuantileDigestAggregationFunction.class) + .aggregates(RealQuantileDigestAggregationFunction.class) + .aggregates(BigintQuantileDigestAggregationFunction.class) + .aggregates(MergeQuantileDigestFunction.class) .aggregates(MergeTDigestAggregation.class) .aggregates(DoubleHistogramAggregation.class) .aggregates(RealHistogramAggregation.class) @@ -518,13 +520,14 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .function(ARRAY_FLATTEN_FUNCTION) .function(ARRAY_CONCAT_FUNCTION) .functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, JSON_TO_ARRAY, JSON_STRING_TO_ARRAY) - .function(ARRAY_AGG) - .function(LISTAGG) + .aggregates(ArrayAggregationFunction.class) + .aggregates(ListaggAggregationFunction.class) .functions(new MapSubscriptOperator()) .functions(MAP_CONSTRUCTOR, JSON_TO_MAP, JSON_STRING_TO_MAP) - .functions(new MapAggregationFunction(blockTypeOperators), new MapUnionAggregation(blockTypeOperators)) + .aggregates(MapAggregationFunction.class) + .aggregates(MapUnionAggregation.class) .function(REDUCE_AGG) - .function(new MultimapAggregationFunction(blockTypeOperators)) + .aggregates(MultimapAggregationFunction.class) .functions(DECIMAL_TO_VARCHAR_CAST, DECIMAL_TO_INTEGER_CAST, DECIMAL_TO_BIGINT_CAST, DECIMAL_TO_DOUBLE_CAST, DECIMAL_TO_REAL_CAST, DECIMAL_TO_BOOLEAN_CAST, DECIMAL_TO_TINYINT_CAST, DECIMAL_TO_SMALLINT_CAST) .functions(VARCHAR_TO_DECIMAL_CAST, INTEGER_TO_DECIMAL_CAST, BIGINT_TO_DECIMAL_CAST, DOUBLE_TO_DECIMAL_CAST, REAL_TO_DECIMAL_CAST, BOOLEAN_TO_DECIMAL_CAST, TINYINT_TO_DECIMAL_CAST, SMALLINT_TO_DECIMAL_CAST) .functions(JSON_TO_DECIMAL_CAST, DECIMAL_TO_JSON_CAST) @@ -534,21 +537,27 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .functions(DECIMAL_TO_INTEGER_SATURATED_FLOOR_CAST, INTEGER_TO_DECIMAL_SATURATED_FLOOR_CAST) .functions(DECIMAL_TO_SMALLINT_SATURATED_FLOOR_CAST, SMALLINT_TO_DECIMAL_SATURATED_FLOOR_CAST) .functions(DECIMAL_TO_TINYINT_SATURATED_FLOOR_CAST, TINYINT_TO_DECIMAL_SATURATED_FLOOR_CAST) - .function(new Histogram(blockTypeOperators)) - .function(new ChecksumAggregationFunction(blockTypeOperators)) - .function(ARBITRARY_AGGREGATION) + .aggregates(Histogram.class) + .aggregates(ChecksumAggregationFunction.class) + .aggregates(ArbitraryAggregationFunction.class) .functions(GREATEST, LEAST) - .functions(MAX_BY, MIN_BY, new MaxByNAggregationFunction(blockTypeOperators), new MinByNAggregationFunction(blockTypeOperators)) - .functions(MAX_AGGREGATION, MIN_AGGREGATION, new MaxNAggregationFunction(blockTypeOperators), new MinNAggregationFunction(blockTypeOperators)) - .function(COUNT_COLUMN) + .aggregates(MinAggregationFunction.class) + .aggregates(MaxAggregationFunction.class) + .aggregates(MinByAggregationFunction.class) + .aggregates(MaxByAggregationFunction.class) + .aggregates(MaxNAggregationFunction.class) + .aggregates(MinNAggregationFunction.class) + .aggregates(MinByNAggregationFunction.class) + .aggregates(MaxByNAggregationFunction.class) + .aggregates(CountColumn.class) .functions(JSON_TO_ROW, JSON_STRING_TO_ROW, ROW_TO_ROW_CAST) .functions(VARCHAR_CONCAT, VARBINARY_CONCAT) .function(CONCAT_WS) .function(DECIMAL_TO_DECIMAL_CAST) .function(castVarcharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries())) .function(castCharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries())) - .function(DECIMAL_AVERAGE_AGGREGATION) - .function(DECIMAL_SUM_AGGREGATION) + .aggregates(DecimalAverageAggregation.class) + .aggregates(DecimalSumAggregation.class) .function(DECIMAL_MOD_FUNCTION) .functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION) .functions(MAP_FILTER_FUNCTION, new MapTransformKeysFunction(blockTypeOperators), MAP_TRANSFORM_VALUES_FUNCTION) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMinMaxAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMinMaxAggregationFunction.java deleted file mode 100644 index 6e20f7aa0e90..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMinMaxAggregationFunction.java +++ /dev/null @@ -1,302 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation; - -import com.google.common.collect.ImmutableList; -import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.state.BlockPositionState; -import io.trino.operator.aggregation.state.BlockPositionStateSerializer; -import io.trino.operator.aggregation.state.GenericBooleanState; -import io.trino.operator.aggregation.state.GenericBooleanStateSerializer; -import io.trino.operator.aggregation.state.GenericDoubleState; -import io.trino.operator.aggregation.state.GenericDoubleStateSerializer; -import io.trino.operator.aggregation.state.GenericLongState; -import io.trino.operator.aggregation.state.GenericLongStateSerializer; -import io.trino.operator.aggregation.state.StateCompiler; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.InvocationConvention; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; - -import java.lang.invoke.MethodHandle; -import java.util.List; -import java.util.Optional; - -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static io.trino.util.Failures.internalError; -import static io.trino.util.MinMaxCompare.getMinMaxCompare; -import static io.trino.util.MinMaxCompare.getMinMaxCompareFunctionDependencies; -import static io.trino.util.Reflection.methodHandle; - -public abstract class AbstractMinMaxAggregationFunction - extends SqlAggregationFunction -{ - private static final MethodHandle LONG_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "input", MethodHandle.class, GenericLongState.class, long.class); - private static final MethodHandle DOUBLE_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "input", MethodHandle.class, GenericDoubleState.class, double.class); - private static final MethodHandle BOOLEAN_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "input", MethodHandle.class, GenericBooleanState.class, boolean.class); - private static final MethodHandle BLOCK_POSITION_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "input", MethodHandle.class, BlockPositionState.class, Block.class, int.class); - - private static final MethodHandle LONG_OUTPUT_FUNCTION = methodHandle(GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class); - private static final MethodHandle DOUBLE_OUTPUT_FUNCTION = methodHandle(GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class); - private static final MethodHandle BOOLEAN_OUTPUT_FUNCTION = methodHandle(GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class); - private static final MethodHandle BLOCK_POSITION_OUTPUT_FUNCTION = methodHandle(BlockPositionState.class, "write", Type.class, BlockPositionState.class, BlockBuilder.class); - - private static final MethodHandle LONG_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, GenericLongState.class, GenericLongState.class); - private static final MethodHandle DOUBLE_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, GenericDoubleState.class, GenericDoubleState.class); - private static final MethodHandle BOOLEAN_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, GenericBooleanState.class, GenericBooleanState.class); - private static final MethodHandle BLOCK_POSITION_COMBINE_FUNCTION = methodHandle(AbstractMinMaxAggregationFunction.class, "combine", MethodHandle.class, BlockPositionState.class, BlockPositionState.class); - - private final boolean min; - - protected AbstractMinMaxAggregationFunction(String name, boolean min, String description) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(name) - .orderableTypeParameter("E") - .returnType(new TypeSignature("E")) - .argumentType(new TypeSignature("E")) - .build()) - .description(description) - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(new TypeSignature("E")) - .build()); - this.min = min; - } - - @Override - public FunctionDependencyDeclaration getFunctionDependencies() - { - return getMinMaxCompareFunctionDependencies(new TypeSignature("E"), min); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) - { - Type type = boundSignature.getArgumentTypes().get(0); - InvocationConvention invocationConvention; - if (type.getJavaType().isPrimitive()) { - invocationConvention = simpleConvention(FAIL_ON_NULL, NEVER_NULL, NEVER_NULL); - } - else { - invocationConvention = simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION); - } - - MethodHandle compareMethodHandle = getMinMaxCompare(functionDependencies, type, invocationConvention, min); - - MethodHandle inputFunction; - MethodHandle combineFunction; - MethodHandle outputFunction; - - AccumulatorStateDescriptor accumulatorStateDescriptor; - if (type.getJavaType() == long.class) { - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - GenericLongState.class, - new GenericLongStateSerializer(type), - StateCompiler.generateStateFactory(GenericLongState.class)); - inputFunction = LONG_INPUT_FUNCTION.bindTo(compareMethodHandle); - combineFunction = LONG_COMBINE_FUNCTION.bindTo(compareMethodHandle); - outputFunction = LONG_OUTPUT_FUNCTION.bindTo(type); - } - else if (type.getJavaType() == double.class) { - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - GenericDoubleState.class, - new GenericDoubleStateSerializer(type), - StateCompiler.generateStateFactory(GenericDoubleState.class)); - inputFunction = DOUBLE_INPUT_FUNCTION.bindTo(compareMethodHandle); - combineFunction = DOUBLE_COMBINE_FUNCTION.bindTo(compareMethodHandle); - outputFunction = DOUBLE_OUTPUT_FUNCTION.bindTo(type); - } - else if (type.getJavaType() == boolean.class) { - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - GenericBooleanState.class, - new GenericBooleanStateSerializer(type), - StateCompiler.generateStateFactory(GenericBooleanState.class)); - inputFunction = BOOLEAN_INPUT_FUNCTION.bindTo(compareMethodHandle); - combineFunction = BOOLEAN_COMBINE_FUNCTION.bindTo(compareMethodHandle); - outputFunction = BOOLEAN_OUTPUT_FUNCTION.bindTo(type); - } - else { - // native container type is Object - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - BlockPositionState.class, - new BlockPositionStateSerializer(type), - StateCompiler.generateStateFactory(BlockPositionState.class)); - inputFunction = BLOCK_POSITION_INPUT_FUNCTION.bindTo(compareMethodHandle); - combineFunction = BLOCK_POSITION_COMBINE_FUNCTION.bindTo(compareMethodHandle); - outputFunction = BLOCK_POSITION_OUTPUT_FUNCTION.bindTo(type); - } - - inputFunction = normalizeInputMethod(inputFunction, boundSignature, createInputParameterKinds(type)); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(combineFunction), - outputFunction, - ImmutableList.of(accumulatorStateDescriptor)); - } - - private static List createInputParameterKinds(Type type) - { - if (type.getJavaType().isPrimitive()) { - return ImmutableList.of( - STATE, - INPUT_CHANNEL); - } - else { - return ImmutableList.of( - STATE, - BLOCK_INPUT_CHANNEL, - BLOCK_INDEX); - } - } - - @UsedByGeneratedCode - public static void input(MethodHandle methodHandle, GenericDoubleState state, double value) - { - compareAndUpdateState(methodHandle, state, value); - } - - @UsedByGeneratedCode - public static void input(MethodHandle methodHandle, GenericLongState state, long value) - { - compareAndUpdateState(methodHandle, state, value); - } - - @UsedByGeneratedCode - public static void input(MethodHandle methodHandle, GenericBooleanState state, boolean value) - { - compareAndUpdateState(methodHandle, state, value); - } - - @UsedByGeneratedCode - public static void input(MethodHandle methodHandle, BlockPositionState state, Block block, int position) - { - compareAndUpdateState(methodHandle, state, block, position); - } - - @UsedByGeneratedCode - public static void combine(MethodHandle methodHandle, GenericLongState state, GenericLongState otherState) - { - compareAndUpdateState(methodHandle, state, otherState.getValue()); - } - - @UsedByGeneratedCode - public static void combine(MethodHandle methodHandle, GenericDoubleState state, GenericDoubleState otherState) - { - compareAndUpdateState(methodHandle, state, otherState.getValue()); - } - - @UsedByGeneratedCode - public static void combine(MethodHandle methodHandle, GenericBooleanState state, GenericBooleanState otherState) - { - compareAndUpdateState(methodHandle, state, otherState.getValue()); - } - - @UsedByGeneratedCode - public static void combine(MethodHandle methodHandle, BlockPositionState state, BlockPositionState otherState) - { - compareAndUpdateState(methodHandle, state, otherState.getBlock(), otherState.getPosition()); - } - - private static void compareAndUpdateState(MethodHandle methodHandle, GenericLongState state, long value) - { - if (state.isNull()) { - state.setNull(false); - state.setValue(value); - return; - } - try { - if ((boolean) methodHandle.invokeExact(value, state.getValue())) { - state.setValue(value); - } - } - catch (Throwable t) { - throw internalError(t); - } - } - - private static void compareAndUpdateState(MethodHandle methodHandle, GenericDoubleState state, double value) - { - if (state.isNull()) { - state.setNull(false); - state.setValue(value); - return; - } - try { - if ((boolean) methodHandle.invokeExact(value, state.getValue())) { - state.setValue(value); - } - } - catch (Throwable t) { - throw internalError(t); - } - } - - private static void compareAndUpdateState(MethodHandle methodHandle, GenericBooleanState state, boolean value) - { - if (state.isNull()) { - state.setNull(false); - state.setValue(value); - return; - } - try { - if ((boolean) methodHandle.invokeExact(value, state.getValue())) { - state.setValue(value); - } - } - catch (Throwable t) { - throw internalError(t); - } - } - - private static void compareAndUpdateState(MethodHandle methodHandle, BlockPositionState state, Block block, int position) - { - if (state.isNull()) { - state.setBlock(block); - state.setPosition(position); - return; - } - try { - if ((boolean) methodHandle.invokeExact(block, position, state.getBlock(), state.getPosition())) { - state.setBlock(block); - state.setPosition(position); - } - } - catch (Throwable t) { - throw internalError(t); - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMinMaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMinMaxNAggregationFunction.java deleted file mode 100644 index 82962856762a..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMinMaxNAggregationFunction.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation; - -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.state.MinMaxNState; -import io.trino.operator.aggregation.state.MinMaxNStateFactory; -import io.trino.operator.aggregation.state.MinMaxNStateSerializer; -import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.util.MinMaxCompare; - -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; -import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.TypeSignature.arrayType; -import static io.trino.util.Failures.checkCondition; -import static io.trino.util.MinMaxCompare.getMinMaxCompare; -import static io.trino.util.Reflection.methodHandle; -import static java.lang.Math.toIntExact; - -public abstract class AbstractMinMaxNAggregationFunction - extends SqlAggregationFunction -{ - private static final MethodHandle INPUT_FUNCTION = methodHandle(AbstractMinMaxNAggregationFunction.class, "input", MethodHandle.class, Type.class, MinMaxNState.class, Block.class, long.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(AbstractMinMaxNAggregationFunction.class, "combine", MinMaxNState.class, MinMaxNState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(AbstractMinMaxNAggregationFunction.class, "output", ArrayType.class, MinMaxNState.class, BlockBuilder.class); - private static final long MAX_NUMBER_OF_VALUES = 10_000; - - private final boolean min; - - protected AbstractMinMaxNAggregationFunction(String name, boolean min, String description) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(name) - .orderableTypeParameter("E") - .returnType(arrayType(new TypeSignature("E"))) - .argumentType(new TypeSignature("E")) - .argumentType(BIGINT) - .build()) - .description(description) - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(BIGINT) - .intermediateType(arrayType(new TypeSignature("E"))) - .build()); - this.min = min; - } - - @Override - public FunctionDependencyDeclaration getFunctionDependencies(BoundSignature boundSignature) - { - return MinMaxCompare.getMinMaxCompareFunctionDependencies(boundSignature.getArgumentTypes().get(0).getTypeSignature(), min); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) - { - Type type = boundSignature.getArgumentTypes().get(0); - MethodHandle compare = getMinMaxCompare(functionDependencies, type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), min); - MinMaxNStateSerializer stateSerializer = new MinMaxNStateSerializer(compare, type); - ArrayType outputType = new ArrayType(type); - - MethodHandle inputFunction = INPUT_FUNCTION.bindTo(compare).bindTo(type); - inputFunction = normalizeInputMethod(inputFunction, boundSignature, STATE, BLOCK_INPUT_CHANNEL, INPUT_CHANNEL, BLOCK_INDEX); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION.bindTo(outputType), - ImmutableList.of(new AccumulatorStateDescriptor<>( - MinMaxNState.class, - stateSerializer, - new MinMaxNStateFactory()))); - } - - public static void input(MethodHandle compare, Type type, MinMaxNState state, Block block, long n, int blockIndex) - { - TypedHeap heap = state.getTypedHeap(); - if (heap == null) { - if (n <= 0) { - throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "second argument of max_n/min_n must be positive"); - } - checkCondition(n <= MAX_NUMBER_OF_VALUES, INVALID_FUNCTION_ARGUMENT, "second argument of max_n/min_n must be less than or equal to %s; found %s", MAX_NUMBER_OF_VALUES, n); - heap = new TypedHeap(compare, type, toIntExact(n)); - state.setTypedHeap(heap); - } - long startSize = heap.getEstimatedSize(); - heap.add(block, blockIndex); - state.addMemoryUsage(heap.getEstimatedSize() - startSize); - } - - public static void combine(MinMaxNState state, MinMaxNState otherState) - { - TypedHeap otherHeap = otherState.getTypedHeap(); - if (otherHeap == null) { - return; - } - TypedHeap heap = state.getTypedHeap(); - if (heap == null) { - state.setTypedHeap(otherHeap); - return; - } - long startSize = heap.getEstimatedSize(); - heap.addAll(otherHeap); - state.addMemoryUsage(heap.getEstimatedSize() - startSize); - } - - public static void output(ArrayType outputType, MinMaxNState state, BlockBuilder out) - { - TypedHeap heap = state.getTypedHeap(); - if (heap == null || heap.isEmpty()) { - out.appendNull(); - return; - } - - Type elementType = outputType.getElementType(); - - BlockBuilder reversedBlockBuilder = elementType.createBlockBuilder(null, heap.getCapacity()); - long startSize = heap.getEstimatedSize(); - heap.popAll(reversedBlockBuilder); - state.addMemoryUsage(heap.getEstimatedSize() - startSize); - - BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - for (int i = reversedBlockBuilder.getPositionCount() - 1; i >= 0; i--) { - elementType.appendTo(reversedBlockBuilder, i, arrayBlockBuilder); - } - out.closeEntry(); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 389a00f32752..44949ac5d81d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -13,18 +13,31 @@ */ package io.trino.operator.aggregation; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.MoreCollectors; import io.airlift.log.Logger; +import io.trino.metadata.FunctionBinding; +import io.trino.metadata.FunctionDependencies; import io.trino.metadata.Signature; import io.trino.operator.ParametricImplementationsGroup; +import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; +import io.trino.operator.aggregation.state.InOutStateSerializer; import io.trino.operator.annotations.FunctionsParserHelper; +import io.trino.operator.annotations.ImplementationDependency; +import io.trino.operator.annotations.TypeImplementationDependency; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.AccumulatorStateMetadata; +import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; import io.trino.spi.function.CombineFunction; import io.trino.spi.function.FunctionDependency; +import io.trino.spi.function.InOut; import io.trino.spi.function.InputFunction; import io.trino.spi.function.LiteralParameter; import io.trino.spi.function.OperatorDependency; @@ -35,21 +48,36 @@ import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; import java.lang.reflect.Method; +import java.lang.reflect.Parameter; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; import java.util.stream.IntStream; +import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.emptyToNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation; +import static io.trino.operator.aggregation.state.StateCompiler.generateInOutStateFactory; +import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory; +import static io.trino.operator.aggregation.state.StateCompiler.generateStateSerializer; +import static io.trino.operator.aggregation.state.StateCompiler.getMetadataAnnotation; import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription; -import static java.util.Collections.nCopies; +import static io.trino.operator.annotations.ImplementationDependency.Factory.createDependency; +import static io.trino.operator.annotations.ImplementationDependency.getImplementationDependencyAnnotation; +import static io.trino.operator.annotations.ImplementationDependency.validateImplementationDependencyAnnotation; +import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature; import static java.util.Objects.requireNonNull; +import static java.util.function.Predicate.not; public final class AggregationFromAnnotationsParser { @@ -64,12 +92,12 @@ public static List parseFunctionDefinitions(Class aggr ImmutableList.Builder functions = ImmutableList.builder(); - // There must be a single state class and combine function - Class stateClass = getStateClass(aggregationDefinition); - Optional combineFunction = getCombineFunction(aggregationDefinition, stateClass); + // There must be a single set of state classes and a single combine function + List> stateDetails = getStateDetails(aggregationDefinition); + Optional combineFunction = getCombineFunction(aggregationDefinition, stateDetails); // Each output function defines a new aggregation function - for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) { + for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateDetails)) { AggregationHeader header = parseHeader(aggregationDefinition, outputFunction); if (header.isDecomposable()) { checkArgument(combineFunction.isPresent(), "Decomposable method %s does not have a combine method", header.getName()); @@ -81,11 +109,12 @@ else if (combineFunction.isPresent()) { // Input functions can have either an exact signature, or generic/calculate signature List exactImplementations = new ArrayList<>(); List nonExactImplementations = new ArrayList<>(); - for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) { + for (Method inputFunction : getInputFunctions(aggregationDefinition, stateDetails)) { Optional removeInputFunction = getRemoveInputFunction(aggregationDefinition, inputFunction); AggregationImplementation implementation = parseImplementation( aggregationDefinition, header.getName(), + stateDetails, inputFunction, removeInputFunction, outputFunction, @@ -99,9 +128,9 @@ else if (combineFunction.isPresent()) { } // register a set functions for the canonical name, and each alias - functions.addAll(buildFunctions(header.getName(), header, stateClass, exactImplementations, nonExactImplementations)); + functions.addAll(buildFunctions(header.getName(), header, stateDetails, exactImplementations, nonExactImplementations)); for (String alias : getAliases(aggregationDefinition.getAnnotation(AggregationFunction.class), outputFunction)) { - functions.addAll(buildFunctions(alias, header, stateClass, exactImplementations, nonExactImplementations)); + functions.addAll(buildFunctions(alias, header, stateDetails, exactImplementations, nonExactImplementations)); } } @@ -111,7 +140,7 @@ else if (combineFunction.isPresent()) { private static List buildFunctions( String name, AggregationHeader header, - Class stateClass, + List> stateDetails, List exactImplementations, List nonExactImplementations) { @@ -122,7 +151,7 @@ private static List buildFunctions( functions.add(new ParametricAggregation( exactImplementation.getSignature().withName(name), header, - stateClass, + stateDetails, ParametricImplementationsGroup.of(exactImplementation).withAlias(name))); } @@ -134,7 +163,7 @@ private static List buildFunctions( functions.add(new ParametricAggregation( implementations.getSignature().withName(name), header, - stateClass, + stateDetails, implementations.withAlias(name))); } @@ -180,48 +209,95 @@ private static List getAliases(AggregationFunction aggregationAnnotation return ImmutableList.copyOf(aggregationAnnotation.alias()); } - private static Optional getCombineFunction(Class clazz, Class stateClass) + private static Optional getCombineFunction(Class clazz, List> stateDetails) { List combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class); - for (Method combineFunction : combineFunctions) { - // verify parameter types - List> parameterTypes = getNonDependencyParameterTypes(combineFunction); - List> expectedParameterTypes = nCopies(2, stateClass); - checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction); + if (combineFunctions.isEmpty()) { + return Optional.empty(); + } + checkArgument(combineFunctions.size() == 1, "There must be only one @CombineFunction in class %s", clazz.toGenericString()); + Method combineFunction = getOnlyElement(combineFunctions); + + // verify parameter types + List> parameterTypes = getNonDependencyParameterTypes(combineFunction); + List> expectedParameterTypes = Stream.concat(stateDetails.stream(), stateDetails.stream()) + .map(AccumulatorStateDetails::getStateClass) + .collect(toImmutableList()); + checkArgument(parameterTypes.equals(expectedParameterTypes), + "Expected combine function non-dependency parameters to be %s: %s", + expectedParameterTypes, + combineFunction); + + // legacy combine functions did not require parameters to be fully annotated + if (stateDetails.size() > 1) { + List> parameterAnnotations = getNonDependencyParameterAnnotations(combineFunction); + List> actualStateDetails = new ArrayList<>(); + for (int parameterIndex = 0; parameterIndex < parameterTypes.size(); parameterIndex++) { + actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex).asSubclass(AccumulatorState.class), parameterAnnotations.get(parameterIndex), combineFunction, true)); + } + List> expectedStateDetails = ImmutableList.>builder().addAll(stateDetails).addAll(stateDetails).build(); + checkArgument(actualStateDetails.equals(expectedStateDetails), "Expected combine function to have state parameters %s, but has %s", stateDetails, expectedStateDetails); } - checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateClass.toGenericString()); - return combineFunctions.stream().findFirst(); + return Optional.of(combineFunction); } - private static List getOutputFunctions(Class clazz, Class stateClass) + private static List getOutputFunctions(Class clazz, List> stateDetails) { List outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class); for (Method outputFunction : outputFunctions) { // verify parameter types List> parameterTypes = getNonDependencyParameterTypes(outputFunction); List> expectedParameterTypes = ImmutableList.>builder() - .add(stateClass) + .addAll(stateDetails.stream().map(AccumulatorStateDetails::getStateClass).collect(toImmutableList())) .add(BlockBuilder.class) .build(); checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected output function non-dependency parameters to be %s: %s", expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()), outputFunction); + + // legacy output functions did not require parameters to be fully annotated + if (stateDetails.size() > 1) { + List> parameterAnnotations = getNonDependencyParameterAnnotations(outputFunction); + + List> actualStateDetails = new ArrayList<>(); + for (int parameterIndex = 0; parameterIndex < stateDetails.size(); parameterIndex++) { + actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex).asSubclass(AccumulatorState.class), parameterAnnotations.get(parameterIndex), outputFunction, true)); + } + checkArgument(actualStateDetails.equals(stateDetails), "Expected output function to have state parameters %s, but has %s", stateDetails, actualStateDetails); + } } checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions"); return outputFunctions; } - private static List getInputFunctions(Class clazz, Class stateClass) + private static List getInputFunctions(Class clazz, List> stateDetails) { List inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class); for (Method inputFunction : inputFunctions) { - // verify state parameter is first non-dependency parameter - Class actualStateType = getNonDependencyParameterTypes(inputFunction).get(0); - checkArgument(stateClass.equals(actualStateType), - "Expected input function non-dependency parameters to begin with state type %s: %s", - stateClass.getSimpleName(), + // verify state parameter types + List> parameterTypes = getNonDependencyParameterTypes(inputFunction) + .subList(0, stateDetails.size()); + List> expectedParameterTypes = ImmutableList.>builder() + .addAll(stateDetails.stream().map(AccumulatorStateDetails::getStateClass).collect(toImmutableList())) + .build() + .subList(0, stateDetails.size()); + checkArgument(parameterTypes.equals(expectedParameterTypes), + "Expected input function non-dependency parameters to begin with state types %s: %s", + expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()), inputFunction); + + // g input functions did not require parameters to be fully annotated + if (stateDetails.size() > 1) { + List> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction) + .subList(0, stateDetails.size()); + + List> actualStateDetails = new ArrayList<>(); + for (int parameterIndex = 0; parameterIndex < stateDetails.size(); parameterIndex++) { + actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex).asSubclass(AccumulatorState.class), parameterAnnotations.get(parameterIndex), inputFunction, false)); + } + checkArgument(actualStateDetails.equals(stateDetails), "Expected input function to have state parameters %s, but has %s", stateDetails, actualStateDetails); + } } checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions"); @@ -246,6 +322,14 @@ private static List> getNonDependencyParameterTypes(Method function) .collect(toImmutableList()); } + private static List> getNonDependencyParameterAnnotations(Method function) + { + Annotation[][] parameterAnnotations = function.getParameterAnnotations(); + return getNonDependencyParameters(function) + .mapToObj(index -> ImmutableList.copyOf(parameterAnnotations[index])) + .collect(toImmutableList()); + } + private static Optional getRemoveInputFunction(Class clazz, Method inputFunction) { // Only include methods which take the same parameters as the corresponding input function @@ -255,20 +339,317 @@ private static Optional getRemoveInputFunction(Class clazz, Method in .collect(MoreCollectors.toOptional()); } - private static Class getStateClass(Class clazz) + private static List> getStateDetails(Class clazz) { - ImmutableSet.Builder> builder = ImmutableSet.builder(); + ImmutableSet.Builder>> builder = ImmutableSet.builder(); for (Method inputFunction : FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) { - checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters"); - Class stateClass = AggregationImplementation.Parser.findAggregationStateParamType(inputFunction); + List> parameterTypes = getNonDependencyParameterTypes(inputFunction); + checkArgument(!parameterTypes.isEmpty(), "Input function has no parameters"); + List> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction); + + ImmutableList.Builder> stateParameters = ImmutableList.builder(); + for (int parameterIndex = 0; parameterIndex < parameterTypes.size(); parameterIndex++) { + Class parameterType = parameterTypes.get(parameterIndex); + if (!AccumulatorState.class.isAssignableFrom(parameterType)) { + continue; + } + + stateParameters.add(toAccumulatorStateDetails(parameterType.asSubclass(AccumulatorState.class), parameterAnnotations.get(parameterIndex), inputFunction, false)); + } + List> states = stateParameters.build(); + checkArgument(!states.isEmpty(), "Input function must have at least one state parameter"); + builder.add(states); + } + Set>> functionStateClasses = builder.build(); + checkArgument(!functionStateClasses.isEmpty(), "No input functions found"); + checkArgument(functionStateClasses.size() == 1, "There must be exactly one set of @AccumulatorState in class %s", clazz.toGenericString()); + + return getOnlyElement(functionStateClasses); + } + + private static AccumulatorStateDetails toAccumulatorStateDetails( + Class stateClass, + List parameterAnnotations, + Method method, + boolean requireAnnotation) + { + Optional state = parameterAnnotations.stream() + .filter(AggregationState.class::isInstance) + .map(AggregationState.class::cast) + .findFirst(); + + if (requireAnnotation) { + checkArgument(state.isPresent(), "AggregationState must be present on AccumulatorState parameters: %s", method); + } + + List declaredTypeParameters = state.map(AggregationState::value) + .map(ImmutableList::copyOf) + .orElse(ImmutableList.of()); + + return toAccumulatorStateDetails(stateClass, declaredTypeParameters); + } + + @VisibleForTesting + public static AccumulatorStateDetails toAccumulatorStateDetails(Class stateClass, List declaredTypeParameters) + { + StateMetadata metadata = new StateMetadata(getMetadataAnnotation(stateClass)); + // Generic state classes have their own type variables, that must be mapped to the aggregation's type variables + TypeSignatureMapping typeParameterMapping = getTypeParameterMapping(stateClass, declaredTypeParameters, metadata); + + if (stateClass.equals(InOut.class)) { + String typeVariable = typeParameterMapping.mapTypeSignature(new TypeSignature("T")).toString(); + @SuppressWarnings("unchecked") AccumulatorStateDetails stateDetails = (AccumulatorStateDetails) getInOutAccumulatorStateDetails(typeVariable); + return stateDetails; + } + + List allDependencies = new ArrayList<>(); + + BiFunction> serializerGenerator; + if (metadata.getStateSerializerClass().isPresent()) { + Constructor constructor = getOnlyConstructor(metadata.getStateSerializerClass().get()); + List dependencies = parseImplementationDependencies(typeParameterMapping, constructor); + serializerGenerator = new TypedFactory<>(constructor, dependencies); + allDependencies.addAll(dependencies); + } + else { + serializerGenerator = (functionBinding, functionDependencies) -> generateStateSerializer(stateClass); + } + + TypeSignature serializedType; + if (metadata.getSerializedType().isPresent()) { + serializedType = typeParameterMapping.mapTypeSignature(parseTypeSignature(metadata.getSerializedType().get(), ImmutableSet.of())); + } + else { + // serialized type is not explicit declared, so we must construct it to get the + // type, but this will only work if there are no dependencies + checkArgument(allDependencies.isEmpty(), "serializedType must be set for state %s with dependencies", stateClass); + AccumulatorStateSerializer serializer = serializerGenerator.apply(null, null); + serializedType = serializer.getSerializedType().getTypeSignature(); + // since there are no dependencies, the same serializer can be used for all + serializerGenerator = (functionBinding, functionDependencies) -> serializer; + } + + BiFunction> factoryGenerator; + if (metadata.getStateFactoryClass().isPresent()) { + Constructor constructor = getOnlyConstructor(metadata.getStateFactoryClass().get()); + List dependencies = parseImplementationDependencies(typeParameterMapping, constructor); + factoryGenerator = new TypedFactory<>(constructor, dependencies); + allDependencies.addAll(dependencies); + } + else { + factoryGenerator = (functionBinding, functionDependencies) -> generateStateFactory(stateClass); + } + + return new AccumulatorStateDetails<>( + stateClass, + declaredTypeParameters, + serializedType, + serializerGenerator, + factoryGenerator, + allDependencies); + } + + private static Constructor getOnlyConstructor(Class clazz) + { + Constructor[] constructors = clazz.getConstructors(); + checkArgument(constructors.length == 1, "Expected %s to have only one public constructor", clazz.getSimpleName()); + return constructors[0]; + } + + private static AccumulatorStateDetails getInOutAccumulatorStateDetails(String typeVariable) + { + TypeSignature serializedType = parseTypeSignature(typeVariable, ImmutableSet.of()); + return new AccumulatorStateDetails<>( + InOut.class, + ImmutableList.of(typeVariable), + serializedType, + (functionBinding, functionDependencies) -> new InOutStateSerializer(functionBinding.getTypeVariable(typeVariable)), + (functionBinding, functionDependencies) -> generateInOutStateFactory(functionBinding.getTypeVariable(typeVariable)), + ImmutableList.of(new TypeImplementationDependency(parseTypeSignature(typeVariable, ImmutableSet.of())))); + } - checkArgument(AccumulatorState.class.isAssignableFrom(stateClass), "stateClass is not a subclass of AccumulatorState"); - builder.add(stateClass.asSubclass(AccumulatorState.class)); + private static TypeSignatureMapping getTypeParameterMapping(Class stateClass, List declaredTypeParameters, StateMetadata metadata) + { + List expectedTypeParameters = metadata.getTypeParameters(); + if (expectedTypeParameters.isEmpty()) { + return new TypeSignatureMapping(ImmutableMap.of()); + } + checkArgument(declaredTypeParameters.size() == expectedTypeParameters.size(), "AggregationState %s requires %s type parameters", stateClass, expectedTypeParameters.size()); + + ImmutableMap.Builder mapping = ImmutableMap.builder(); + for (int parameterIndex = 0; parameterIndex < declaredTypeParameters.size(); parameterIndex++) { + String declaredTypeParameter = declaredTypeParameters.get(parameterIndex); + String expectedTypeParameter = expectedTypeParameters.get(parameterIndex); + mapping.put(expectedTypeParameter, declaredTypeParameter); } - ImmutableSet> stateClasses = builder.build(); - checkArgument(!stateClasses.isEmpty(), "No input functions found"); - checkArgument(stateClasses.size() == 1, "There must be exactly one @AccumulatorState in class %s", clazz.toGenericString()); + return new TypeSignatureMapping(mapping.buildOrThrow()); + } + + public static List parseImplementationDependencies(TypeSignatureMapping typeSignatureMapping, Executable inputFunction) + { + ImmutableList.Builder builder = ImmutableList.builder(); + + for (Parameter parameter : inputFunction.getParameters()) { + getImplementationDependencyAnnotation(parameter).ifPresent(annotation -> { + // check if only declared typeParameters and literalParameters are used + validateImplementationDependencyAnnotation( + inputFunction, + annotation, + typeSignatureMapping.getTypeParameters(), + ImmutableSet.of()); + ImplementationDependency dependency = createDependency(annotation, ImmutableSet.of(), parameter.getType()); + dependency = typeSignatureMapping.mapTypes(dependency); + builder.add(dependency); + }); + } + return builder.build(); + } + + private static class StateMetadata + { + private final Optional>> stateSerializerClass; + private final Optional>> stateFactoryClass; + private final List typeParameters; + private final Optional serializedType; + + public StateMetadata(AccumulatorStateMetadata metadata) + { + if (metadata == null) { + stateSerializerClass = Optional.empty(); + stateFactoryClass = Optional.empty(); + typeParameters = ImmutableList.of(); + serializedType = Optional.empty(); + } + else { + //noinspection unchecked + stateSerializerClass = Optional.of(metadata.stateSerializerClass()) + .filter(not(AccumulatorStateSerializer.class::equals)) + .map(type -> (Class>) type); + //noinspection unchecked + stateFactoryClass = Optional.of(metadata.stateFactoryClass()) + .filter(not(AccumulatorStateFactory.class::equals)) + .map(type -> (Class>) type); + typeParameters = ImmutableList.copyOf(metadata.typeParameters()); + serializedType = Optional.of(metadata.serializedType()) + .filter(not(String::isEmpty)); + } + } + + public Optional>> getStateSerializerClass() + { + return stateSerializerClass; + } + + public Optional>> getStateFactoryClass() + { + return stateFactoryClass; + } + + public List getTypeParameters() + { + return typeParameters; + } + + public Optional getSerializedType() + { + return serializedType; + } + } - return getOnlyElement(stateClasses); + public static class AccumulatorStateDetails + { + private final Class stateClass; + private final List typeParameters; + private final TypeSignature serializedType; + private final BiFunction> serializerGenerator; + private final BiFunction> factoryGenerator; + private final List dependencies; + + public AccumulatorStateDetails( + Class stateClass, + List typeParameters, + TypeSignature serializedType, + BiFunction> serializerGenerator, + BiFunction> factoryGenerator, + List dependencies) + { + this.stateClass = requireNonNull(stateClass, "stateClass is null"); + this.typeParameters = ImmutableList.copyOf(requireNonNull(typeParameters, "typeParameters is null")); + this.serializedType = requireNonNull(serializedType, "serializedType is null"); + this.serializerGenerator = requireNonNull(serializerGenerator, "serializerGenerator is null"); + this.factoryGenerator = requireNonNull(factoryGenerator, "factoryGenerator is null"); + this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); + } + + public Class getStateClass() + { + return stateClass; + } + + public TypeSignature getSerializedType() + { + return serializedType; + } + + public List getDependencies() + { + return dependencies; + } + + public AccumulatorStateDescriptor createAccumulatorStateDescriptor(FunctionBinding functionBinding, FunctionDependencies functionDependencies) + { + return new AccumulatorStateDescriptor<>( + stateClass, + serializerGenerator.apply(functionBinding, functionDependencies), + factoryGenerator.apply(functionBinding, functionDependencies)); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AccumulatorStateDetails that = (AccumulatorStateDetails) o; + return Objects.equals(stateClass, that.stateClass) && Objects.equals(typeParameters, that.typeParameters); + } + + @Override + public int hashCode() + { + return Objects.hash(stateClass, typeParameters); + } + } + + private static class TypedFactory + implements BiFunction + { + private final Constructor constructor; + private final List dependencies; + + public TypedFactory(Constructor constructor, List dependencies) + { + //noinspection unchecked + this.constructor = (Constructor) constructor; + this.dependencies = dependencies; + } + + @Override + public T apply(FunctionBinding functionBinding, FunctionDependencies functionDependencies) + { + List values = dependencies.stream() + .map(dependency -> dependency.resolve(functionBinding, functionDependencies)) + .collect(toImmutableList()); + + try { + return constructor.newInstance(values.toArray()); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java index 9350b50e31f8..20542d890fd8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java @@ -19,6 +19,7 @@ import io.trino.metadata.FunctionNullability; import io.trino.metadata.Signature; import io.trino.operator.ParametricImplementation; +import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; import io.trino.operator.annotations.ImplementationDependency; import io.trino.spi.block.Block; @@ -37,6 +38,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Optional; import java.util.Set; @@ -215,10 +217,12 @@ public boolean areTypesAssignable(BoundSignature boundSignature) Class methodDeclaredType = argumentNativeContainerTypes.get(i).getJavaType(); boolean isCurrentBlockPosition = argumentNativeContainerTypes.get(i).isBlockPosition(); + // block and position works for any type, but if block is annotated with SqlType nativeContainerType, then only types with the + // specified container type match if (isCurrentBlockPosition && methodDeclaredType.isAssignableFrom(Block.class)) { continue; } - if (!isCurrentBlockPosition && methodDeclaredType.isAssignableFrom(argumentType)) { + if (methodDeclaredType.isAssignableFrom(argumentType)) { continue; } return false; @@ -267,6 +271,7 @@ public static final class Parser private Parser( Class aggregationDefinition, String name, + List> stateDetails, Method inputFunction, Optional removeInputFunction, Method outputFunction, @@ -294,6 +299,7 @@ private Parser( parseLongVariableConstraints(inputFunction, signatureBuilder); List allDependencies = Stream.of( + stateDetails.stream().map(AccumulatorStateDetails::getDependencies).flatMap(Collection::stream), inputDependencies.stream(), removeInputDependencies.stream(), outputDependencies.stream(), @@ -337,12 +343,13 @@ private AggregationImplementation get() public static AggregationImplementation parseImplementation( Class aggregationDefinition, String name, + List> stateDetails, Method inputFunction, Optional removeInputFunction, Method outputFunction, Optional combineFunction) { - return new Parser(aggregationDefinition, name, inputFunction, removeInputFunction, outputFunction, combineFunction).get(); + return new Parser(aggregationDefinition, name, stateDetails, inputFunction, removeInputFunction, outputFunction, combineFunction).get(); } private static List parseInputParameterKinds(Method method) @@ -441,7 +448,16 @@ public static List parseSignatureArgumentsTypes(Me continue; } - builder.add(new AggregateNativeContainerType(inputFunction.getParameterTypes()[i], isParameterBlock(annotations))); + Optional> nativeContainerType = Arrays.stream(annotations) + .filter(SqlType.class::isInstance) + .map(SqlType.class::cast) + .findFirst() + .map(SqlType::nativeContainerType); + // Note: this cannot be done as a chain due to strange generic type mismatches + if (nativeContainerType.isPresent() && !nativeContainerType.get().equals(Object.class)) { + parameterType = nativeContainerType.get(); + } + builder.add(new AggregateNativeContainerType(parameterType, isParameterBlock(annotations))); } return builder.build(); @@ -501,11 +517,6 @@ public List getInputTypesSignatures(Method inputFunction) return builder.build(); } - public static Class findAggregationStateParamType(Method inputFunction) - { - return inputFunction.getParameterTypes()[findAggregationStateParamId(inputFunction)]; - } - public static int findAggregationStateParamId(Method method) { return findAggregationStateParamId(method, 0); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java index 4f7f2ab7a757..4d6af393bcc0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java @@ -13,192 +13,53 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.state.BlockPositionState; -import io.trino.operator.aggregation.state.BlockPositionStateSerializer; -import io.trino.operator.aggregation.state.GenericBooleanState; -import io.trino.operator.aggregation.state.GenericBooleanStateSerializer; -import io.trino.operator.aggregation.state.GenericDoubleState; -import io.trino.operator.aggregation.state.GenericDoubleStateSerializer; -import io.trino.operator.aggregation.state.GenericLongState; -import io.trino.operator.aggregation.state.GenericLongStateSerializer; -import io.trino.operator.aggregation.state.StateCompiler; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; - -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static io.trino.util.Reflection.methodHandle; - -public class ArbitraryAggregationFunction - extends SqlAggregationFunction +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +@AggregationFunction("arbitrary") +@Description("Return an arbitrary non-null input value") +public final class ArbitraryAggregationFunction { - public static final ArbitraryAggregationFunction ARBITRARY_AGGREGATION = new ArbitraryAggregationFunction(); - private static final String NAME = "arbitrary"; - - private static final MethodHandle LONG_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, GenericLongState.class, Block.class, int.class); - private static final MethodHandle DOUBLE_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, GenericDoubleState.class, Block.class, int.class); - private static final MethodHandle BOOLEAN_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, GenericBooleanState.class, Block.class, int.class); - private static final MethodHandle BLOCK_POSITION_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, BlockPositionState.class, Block.class, int.class); - - private static final MethodHandle LONG_OUTPUT_FUNCTION = methodHandle(GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class); - private static final MethodHandle DOUBLE_OUTPUT_FUNCTION = methodHandle(GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class); - private static final MethodHandle BOOLEAN_OUTPUT_FUNCTION = methodHandle(GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class); - private static final MethodHandle BLOCK_POSITION_OUTPUT_FUNCTION = methodHandle(BlockPositionState.class, "write", Type.class, BlockPositionState.class, BlockBuilder.class); - - private static final MethodHandle LONG_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", GenericLongState.class, GenericLongState.class); - private static final MethodHandle DOUBLE_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", GenericDoubleState.class, GenericDoubleState.class); - private static final MethodHandle BOOLEAN_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", GenericBooleanState.class, GenericBooleanState.class); - private static final MethodHandle BLOCK_POSITION_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", BlockPositionState.class, BlockPositionState.class); - - protected ArbitraryAggregationFunction() + private ArbitraryAggregationFunction() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @AggregationState("T") InOut state, + @BlockPosition @SqlType("T") Block block, + @BlockIndex int position) + throws Throwable { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .typeVariable("T") - .returnType(new TypeSignature("T")) - .argumentType(new TypeSignature("T")) - .build()) - .description("Return an arbitrary non-null input value") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(new TypeSignature("T")) - .build()); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Type type = boundSignature.getReturnType(); - - MethodHandle inputFunction; - MethodHandle combineFunction; - MethodHandle outputFunction; - AccumulatorStateDescriptor accumulatorStateDescriptor; - - if (type.getJavaType() == long.class) { - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - GenericLongState.class, - new GenericLongStateSerializer(type), - StateCompiler.generateStateFactory(GenericLongState.class)); - inputFunction = LONG_INPUT_FUNCTION; - combineFunction = LONG_COMBINE_FUNCTION; - outputFunction = LONG_OUTPUT_FUNCTION; - } - else if (type.getJavaType() == double.class) { - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - GenericDoubleState.class, - new GenericDoubleStateSerializer(type), - StateCompiler.generateStateFactory(GenericDoubleState.class)); - inputFunction = DOUBLE_INPUT_FUNCTION; - combineFunction = DOUBLE_COMBINE_FUNCTION; - outputFunction = DOUBLE_OUTPUT_FUNCTION; - } - else if (type.getJavaType() == boolean.class) { - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - GenericBooleanState.class, - new GenericBooleanStateSerializer(type), - StateCompiler.generateStateFactory(GenericBooleanState.class)); - inputFunction = BOOLEAN_INPUT_FUNCTION; - combineFunction = BOOLEAN_COMBINE_FUNCTION; - outputFunction = BOOLEAN_OUTPUT_FUNCTION; - } - else { - // native container type is Slice or Block - accumulatorStateDescriptor = new AccumulatorStateDescriptor<>( - BlockPositionState.class, - new BlockPositionStateSerializer(type), - StateCompiler.generateStateFactory(BlockPositionState.class)); - inputFunction = BLOCK_POSITION_INPUT_FUNCTION; - combineFunction = BLOCK_POSITION_COMBINE_FUNCTION; - outputFunction = BLOCK_POSITION_OUTPUT_FUNCTION; - } - inputFunction = inputFunction.bindTo(type); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(combineFunction), - outputFunction.bindTo(type), - ImmutableList.of(accumulatorStateDescriptor)); - } - - public static void input(Type type, GenericDoubleState state, Block block, int position) - { - if (!state.isNull()) { - return; + if (state.isNull()) { + state.set(block, position); } - state.setNull(false); - state.setValue(type.getDouble(block, position)); } - public static void input(Type type, GenericLongState state, Block block, int position) + @CombineFunction + public static void combine( + @AggregationState("T") InOut state, + @AggregationState("T") InOut otherState) + throws Throwable { - if (!state.isNull()) { - return; + if (state.isNull()) { + state.set(otherState); } - state.setNull(false); - state.setValue(type.getLong(block, position)); } - public static void input(Type type, GenericBooleanState state, Block block, int position) + @OutputFunction("T") + public static void output(@AggregationState("T") InOut state, BlockBuilder out) { - if (!state.isNull()) { - return; - } - state.setNull(false); - state.setValue(type.getBoolean(block, position)); - } - - public static void input(Type type, BlockPositionState state, Block block, int position) - { - if (state.getBlock() != null) { - return; - } - state.setBlock(block); - state.setPosition(position); - } - - public static void combine(GenericLongState state, GenericLongState otherState) - { - if (!state.isNull()) { - return; - } - state.set(otherState); - } - - public static void combine(GenericDoubleState state, GenericDoubleState otherState) - { - if (!state.isNull()) { - return; - } - state.set(otherState); - } - - public static void combine(GenericBooleanState state, GenericBooleanState otherState) - { - if (!state.isNull()) { - return; - } - state.set(otherState); - } - - public static void combine(BlockPositionState state, BlockPositionState otherState) - { - if (state.getBlock() != null) { - return; - } - state.set(otherState); + state.get(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java index 5e1aa28b8925..498fd91a6b12 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java @@ -14,96 +14,75 @@ package io.trino.operator.aggregation; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.aggregation.state.NullableLongState; -import io.trino.operator.aggregation.state.StateCompiler; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.TypeSignature; -import io.trino.type.BlockTypeOperators; -import io.trino.type.BlockTypeOperators.BlockPositionXxHash64; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.airlift.slice.Slices.wrappedLongArray; -import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.util.Reflection.methodHandle; -import static java.util.Objects.requireNonNull; -public class ChecksumAggregationFunction - extends SqlAggregationFunction +@AggregationFunction("checksum") +@Description("Checksum of the given values") +public final class ChecksumAggregationFunction { @VisibleForTesting public static final long PRIME64 = 0x9E3779B185EBCA87L; - private static final String NAME = "checksum"; - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ChecksumAggregationFunction.class, "output", NullableLongState.class, BlockBuilder.class); - private static final MethodHandle INPUT_FUNCTION = methodHandle(ChecksumAggregationFunction.class, "input", BlockPositionXxHash64.class, NullableLongState.class, Block.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(ChecksumAggregationFunction.class, "combine", NullableLongState.class, NullableLongState.class); - private final BlockTypeOperators blockTypeOperators; + private ChecksumAggregationFunction() {} - public ChecksumAggregationFunction(BlockTypeOperators blockTypeOperators) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .comparableTypeParameter("T") - .returnType(VARBINARY) - .argumentType(new TypeSignature("T")) - .build()) - .argumentNullability(true) - .description("Checksum of the given values") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(BIGINT) - .build()); - this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - BlockPositionXxHash64 xxHash64Operator = blockTypeOperators.getXxHash64Operator(boundSignature.getArgumentTypes().get(0)); - AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(NullableLongState.class); - return new AggregationMetadata( - INPUT_FUNCTION.bindTo(xxHash64Operator), - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION, - ImmutableList.of(new AccumulatorStateDescriptor<>( - NullableLongState.class, - stateSerializer, - StateCompiler.generateStateFactory(NullableLongState.class)))); - } - - public static void input(BlockPositionXxHash64 xxHash64Operator, NullableLongState state, Block block, int position) + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency( + operator = OperatorType.XX_HASH_64, + argumentTypes = "T", + convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) + MethodHandle xxHash64Operator, + @AggregationState NullableLongState state, + @NullablePosition @BlockPosition @SqlType("T") Block block, + @BlockIndex int position) + throws Throwable { state.setNull(false); if (block.isNull(position)) { state.setValue(state.getValue() + PRIME64); } else { - state.setValue(state.getValue() + xxHash64Operator.xxHash64(block, position) * PRIME64); + long valueHash = (long) xxHash64Operator.invokeExact(block, position); + state.setValue(state.getValue() + valueHash * PRIME64); } } - public static void combine(NullableLongState state, NullableLongState otherState) + @CombineFunction + public static void combine( + @AggregationState NullableLongState state, + @AggregationState NullableLongState otherState) { state.setNull(state.isNull() && otherState.isNull()); state.setValue(state.getValue() + otherState.getValue()); } - public static void output(NullableLongState state, BlockBuilder out) + @OutputFunction("VARBINARY") + public static void output( + @AggregationState NullableLongState state, + BlockBuilder out) { if (state.isNull()) { out.appendNull(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java index 0754d26e56d6..9163dab81c84 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java @@ -13,87 +13,56 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.aggregation.state.LongState; -import io.trino.operator.aggregation.state.StateCompiler; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorStateFactory; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.TypeSignature; - -import java.lang.invoke.MethodHandle; -import java.util.Optional; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.RemoveInputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.util.Reflection.methodHandle; -public class CountColumn - extends SqlAggregationFunction +@AggregationFunction("count") +@Description("Counts the non-null values") +public final class CountColumn { - public static final CountColumn COUNT_COLUMN = new CountColumn(); - private static final String NAME = "count"; - private static final MethodHandle INPUT_FUNCTION = methodHandle(CountColumn.class, "input", LongState.class, Block.class, int.class); - private static final MethodHandle REMOVE_INPUT_FUNCTION = methodHandle(CountColumn.class, "removeInput", LongState.class, Block.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(CountColumn.class, "combine", LongState.class, LongState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(CountColumn.class, "output", LongState.class, BlockBuilder.class); - - public CountColumn() - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .typeVariable("T") - .returnType(BIGINT) - .argumentType(new TypeSignature("T")) - .build()) - .description("Counts the non-null values") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(BIGINT) - .build()); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(LongState.class); - AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(LongState.class); - - return new AggregationMetadata( - INPUT_FUNCTION, - Optional.of(REMOVE_INPUT_FUNCTION), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION, - ImmutableList.of(new AccumulatorStateDescriptor<>( - LongState.class, - stateSerializer, - stateFactory))); - } + private CountColumn() {} - public static void input(LongState state, Block block, int index) + @InputFunction + @TypeParameter("T") + public static void input( + @AggregationState LongState state, + @BlockPosition @SqlType("T") Block block, + @BlockIndex int position) { state.setValue(state.getValue() + 1); } - public static void removeInput(LongState state, Block block, int index) + @RemoveInputFunction + public static void removeInput( + @AggregationState LongState state, + @BlockPosition @SqlType("T") Block block, + @BlockIndex int position) { state.setValue(state.getValue() - 1); } - public static void combine(LongState state, LongState otherState) + @CombineFunction + public static void combine(@AggregationState LongState state, LongState otherState) { state.setValue(state.getValue() + otherState.getValue()); } - public static void output(LongState state, BlockBuilder out) + @OutputFunction("BIGINT") + public static void output(@AggregationState LongState state, BlockBuilder out) { BIGINT.writeLong(out, state.getValue()); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java index bc618afe8099..fb930e02ae0a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java @@ -14,106 +14,50 @@ package io.trino.operator.aggregation; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState; -import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory; -import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateSerializer; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.LiteralParameters; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import java.lang.invoke.MethodHandle; import java.math.BigDecimal; import java.math.BigInteger; -import java.util.Optional; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.trino.spi.type.Decimals.overflows; import static io.trino.spi.type.Decimals.writeShortDecimal; import static io.trino.spi.type.Int128Math.addWithOverflow; import static io.trino.spi.type.Int128Math.divideRoundUp; -import static io.trino.spi.type.TypeSignatureParameter.typeVariable; -import static io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.util.Reflection.methodHandle; import static java.math.BigDecimal.ROUND_HALF_UP; -public class DecimalAverageAggregation - extends SqlAggregationFunction +@AggregationFunction("avg") +@Description("Calculates the average value") +public final class DecimalAverageAggregation { - public static final DecimalAverageAggregation DECIMAL_AVERAGE_AGGREGATION = new DecimalAverageAggregation(); - - private static final String NAME = "avg"; - private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputShortDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class); - private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputLongDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class); - - private static final MethodHandle SHORT_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "outputShortDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class); - private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "outputLongDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class); - - private static final MethodHandle COMBINE_FUNCTION = methodHandle(DecimalAverageAggregation.class, "combine", LongDecimalWithOverflowAndLongState.class, LongDecimalWithOverflowAndLongState.class); - private static final BigInteger TWO = new BigInteger("2"); private static final BigInteger OVERFLOW_MULTIPLIER = TWO.pow(128); - public DecimalAverageAggregation() - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .returnType(new TypeSignature("decimal", typeVariable("p"), typeVariable("s"))) - .argumentType(new TypeSignature("decimal", typeVariable("p"), typeVariable("s"))) - .build()) - .description("Calculates the average value") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(VARBINARY) - .build()); - } + private DecimalAverageAggregation() {} - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Type type = getOnlyElement(boundSignature.getArgumentTypes()); - checkArgument(type instanceof DecimalType, "type must be Decimal"); - MethodHandle inputFunction; - MethodHandle outputFunction; - Class stateInterface = LongDecimalWithOverflowAndLongState.class; - LongDecimalWithOverflowAndLongStateSerializer stateSerializer = new LongDecimalWithOverflowAndLongStateSerializer(); - - if (((DecimalType) type).isShort()) { - inputFunction = SHORT_DECIMAL_INPUT_FUNCTION; - outputFunction = SHORT_DECIMAL_OUTPUT_FUNCTION; - } - else { - inputFunction = LONG_DECIMAL_INPUT_FUNCTION; - outputFunction = LONG_DECIMAL_OUTPUT_FUNCTION; - } - outputFunction = outputFunction.bindTo(type); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - outputFunction, - ImmutableList.of(new AccumulatorStateDescriptor<>( - stateInterface, - stateSerializer, - new LongDecimalWithOverflowAndLongStateFactory()))); - } - - public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state, Block block, int position) + @InputFunction + @LiteralParameters({"p", "s"}) + public static void inputShortDecimal( + @AggregationState LongDecimalWithOverflowAndLongState state, + @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = long.class) Block block, + @BlockIndex int position) { state.addLong(1); // row counter @@ -136,7 +80,12 @@ public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state, state.addOverflow(overflow); } - public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, Block block, int position) + @InputFunction + @LiteralParameters({"p", "s"}) + public static void inputLongDecimal( + @AggregationState LongDecimalWithOverflowAndLongState state, + @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = Int128.class) Block block, + @BlockIndex int position) { state.addLong(1); // row counter @@ -159,7 +108,8 @@ public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, B state.addOverflow(overflow); } - public static void combine(LongDecimalWithOverflowAndLongState state, LongDecimalWithOverflowAndLongState otherState) + @CombineFunction + public static void combine(@AggregationState LongDecimalWithOverflowAndLongState state, @AggregationState LongDecimalWithOverflowAndLongState otherState) { state.addLong(otherState.getLong()); // row counter @@ -187,23 +137,23 @@ public static void combine(LongDecimalWithOverflowAndLongState state, LongDecima } } - public static void outputShortDecimal(DecimalType type, LongDecimalWithOverflowAndLongState state, BlockBuilder out) + @OutputFunction("decimal(p,s)") + public static void outputShortDecimal( + @TypeParameter("decimal(p,s)") Type type, + @AggregationState LongDecimalWithOverflowAndLongState state, + BlockBuilder out) { + DecimalType decimalType = (DecimalType) type; if (state.getLong() == 0) { out.appendNull(); + return; } - else { - writeShortDecimal(out, average(state, type).toLongExact()); - } - } - - public static void outputLongDecimal(DecimalType type, LongDecimalWithOverflowAndLongState state, BlockBuilder out) - { - if (state.getLong() == 0) { - out.appendNull(); + Int128 average = average(state, decimalType); + if (decimalType.isShort()) { + writeShortDecimal(out, average.toLongExact()); } else { - type.writeObject(out, average(state, type)); + type.writeObject(out, average); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java index 08890618db45..ea7ad1e1be77 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java @@ -13,98 +13,42 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; -import io.trino.operator.aggregation.state.LongDecimalWithOverflowStateFactory; -import io.trino.operator.aggregation.state.LongDecimalWithOverflowStateSerializer; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.DecimalType; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.LiteralParameters; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; import io.trino.spi.type.Decimals; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; +import io.trino.spi.type.Int128; -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.trino.spi.type.Int128Math.addWithOverflow; -import static io.trino.spi.type.TypeSignatureParameter.numericParameter; -import static io.trino.spi.type.TypeSignatureParameter.typeVariable; -import static io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.util.Reflection.methodHandle; -public class DecimalSumAggregation - extends SqlAggregationFunction +@AggregationFunction("sum") +@Description("Calculates the sum over the input values") +public final class DecimalSumAggregation { - public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation(); - private static final String NAME = "sum"; - private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputShortDecimal", LongDecimalWithOverflowState.class, Block.class, int.class); - private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputLongDecimal", LongDecimalWithOverflowState.class, Block.class, int.class); - - private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "outputLongDecimal", LongDecimalWithOverflowState.class, BlockBuilder.class); - - private static final MethodHandle COMBINE_FUNCTION = methodHandle(DecimalSumAggregation.class, "combine", LongDecimalWithOverflowState.class, LongDecimalWithOverflowState.class); - - public DecimalSumAggregation() - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .returnType(new TypeSignature("decimal", numericParameter(38), typeVariable("s"))) - .argumentType(new TypeSignature("decimal", typeVariable("p"), typeVariable("s"))) - .build()) - .description("Calculates the sum over the input values") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(VARBINARY) - .build()); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Type inputType = getOnlyElement(boundSignature.getArgumentTypes()); - checkArgument(inputType instanceof DecimalType, "type must be Decimal"); - MethodHandle inputFunction; - Class stateInterface = LongDecimalWithOverflowState.class; - LongDecimalWithOverflowStateSerializer stateSerializer = new LongDecimalWithOverflowStateSerializer(); - - if (((DecimalType) inputType).isShort()) { - inputFunction = SHORT_DECIMAL_INPUT_FUNCTION; - } - else { - inputFunction = LONG_DECIMAL_INPUT_FUNCTION; - } - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - LONG_DECIMAL_OUTPUT_FUNCTION, - ImmutableList.of(new AccumulatorStateDescriptor<>( - stateInterface, - stateSerializer, - new LongDecimalWithOverflowStateFactory()))); - } + private DecimalSumAggregation() {} - public static void inputShortDecimal(LongDecimalWithOverflowState state, Block block, int position) + @InputFunction + @LiteralParameters({"p", "s"}) + public static void inputShortDecimal( + @AggregationState LongDecimalWithOverflowState state, + @SqlType("decimal(p,s)") long rightLow) { state.setNotNull(); long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); - long rightLow = block.getLong(position, 0); long rightHigh = rightLow >> 63; long overflow = addWithOverflow( @@ -117,7 +61,12 @@ public static void inputShortDecimal(LongDecimalWithOverflowState state, Block b state.setOverflow(Math.addExact(overflow, state.getOverflow())); } - public static void inputLongDecimal(LongDecimalWithOverflowState state, Block block, int position) + @InputFunction + @LiteralParameters({"p", "s"}) + public static void inputLongDecimal( + @AggregationState LongDecimalWithOverflowState state, + @BlockPosition @SqlType(value = "decimal(p,s)", nativeContainerType = Int128.class) Block block, + @BlockIndex int position) { state.setNotNull(); @@ -138,7 +87,8 @@ public static void inputLongDecimal(LongDecimalWithOverflowState state, Block bl state.addOverflow(overflow); } - public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOverflowState otherState) + @CombineFunction + public static void combine(@AggregationState LongDecimalWithOverflowState state, @AggregationState LongDecimalWithOverflowState otherState) { long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); @@ -164,7 +114,8 @@ public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOv } } - public static void outputLongDecimal(LongDecimalWithOverflowState state, BlockBuilder out) + @OutputFunction("decimal(38,s)") + public static void outputLongDecimal(@AggregationState LongDecimalWithOverflowState state, BlockBuilder out) { if (state.isNotNull()) { if (state.getOverflow() != 0) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java index 7aedcab583aa..6c5de06b3659 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java @@ -13,106 +13,56 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.state.KeyValuePairStateSerializer; import io.trino.operator.aggregation.state.KeyValuePairsState; -import io.trino.operator.aggregation.state.KeyValuePairsStateFactory; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.MapType; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.type.BlockTypeOperators; import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.util.Optional; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.type.TypeSignature.mapType; -import static io.trino.util.Reflection.methodHandle; -import static java.util.Objects.requireNonNull; - -public class MapAggregationFunction - extends SqlAggregationFunction +@AggregationFunction(value = "map_agg", isOrderSensitive = true) +@Description("Aggregates all the rows (key/value pairs) into a single map") +public final class MapAggregationFunction { - public static final String NAME = "map_agg"; - private static final MethodHandle INPUT_FUNCTION = methodHandle( - MapAggregationFunction.class, - "input", - Type.class, - BlockPositionEqual.class, - BlockPositionHashCode.class, - Type.class, - KeyValuePairsState.class, - Block.class, - Block.class, - int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(MapAggregationFunction.class, "combine", KeyValuePairsState.class, KeyValuePairsState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(MapAggregationFunction.class, "output", KeyValuePairsState.class, BlockBuilder.class); - - private final BlockTypeOperators blockTypeOperators; - - public MapAggregationFunction(BlockTypeOperators blockTypeOperators) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .comparableTypeParameter("K") - .typeVariable("V") - .returnType(mapType(new TypeSignature("K"), new TypeSignature("V"))) - .argumentType(new TypeSignature("K")) - .argumentType(new TypeSignature("V")) - .build()) - .argumentNullability(false, true) - .description("Aggregates all the rows (key/value pairs) into a single map") - .build(), - AggregationFunctionMetadata.builder() - .orderSensitive() - .intermediateType(mapType(new TypeSignature("K"), new TypeSignature("V"))) - .build()); - this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - MapType outputType = (MapType) boundSignature.getReturnType(); - Type keyType = outputType.getKeyType(); - BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType); - BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType); - - Type valueType = outputType.getValueType(); - KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(outputType, keyEqual, keyHashCode); - - return new AggregationMetadata( - MethodHandles.insertArguments(INPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType), - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION, - ImmutableList.of(new AccumulatorStateDescriptor<>( - KeyValuePairsState.class, - stateSerializer, - new KeyValuePairsStateFactory(keyType, valueType)))); - } + private MapAggregationFunction() {} + @InputFunction + @TypeParameter("K") + @TypeParameter("V") public static void input( - Type keyType, - BlockPositionEqual keyEqual, - BlockPositionHashCode keyHashCode, - Type valueType, - KeyValuePairsState state, - Block key, - Block value, - int position) + @TypeParameter("K") Type keyType, + @OperatorDependency( + operator = OperatorType.EQUAL, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) + BlockPositionEqual keyEqual, + @OperatorDependency( + operator = OperatorType.HASH_CODE, + argumentTypes = "K", + convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) + BlockPositionHashCode keyHashCode, + @TypeParameter("V") Type valueType, + @AggregationState({"K", "V"}) KeyValuePairsState state, + @BlockPosition @SqlType("K") Block key, + @NullablePosition @BlockPosition @SqlType("V") Block value, + @BlockIndex int position) { KeyValuePairs pairs = state.get(); if (pairs == null) { @@ -125,7 +75,10 @@ public static void input( state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize); } - public static void combine(KeyValuePairsState state, KeyValuePairsState otherState) + @CombineFunction + public static void combine( + @AggregationState({"K", "V"}) KeyValuePairsState state, + @AggregationState({"K", "V"}) KeyValuePairsState otherState) { if (state.get() != null && otherState.get() != null) { Block keys = otherState.get().getKeys(); @@ -142,7 +95,8 @@ else if (state.get() == null) { } } - public static void output(KeyValuePairsState state, BlockBuilder out) + @OutputFunction("map(K, V)") + public static void output(@AggregationState({"K", "V"}) KeyValuePairsState state, BlockBuilder out) { KeyValuePairs pairs = state.get(); if (pairs == null) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java index 54c0663ffc24..c37f40c38156 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java @@ -13,105 +13,52 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.state.KeyValuePairStateSerializer; import io.trino.operator.aggregation.state.KeyValuePairsState; -import io.trino.operator.aggregation.state.KeyValuePairsStateFactory; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.MapType; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.type.BlockTypeOperators; import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.util.Optional; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; -import static io.trino.spi.type.TypeSignature.mapType; -import static io.trino.util.Reflection.methodHandle; -import static java.util.Objects.requireNonNull; - -public class MapUnionAggregation - extends SqlAggregationFunction +@AggregationFunction("map_union") +@Description("Aggregate all the maps into a single map") +public final class MapUnionAggregation { - public static final String NAME = "map_union"; - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(MapUnionAggregation.class, "output", KeyValuePairsState.class, BlockBuilder.class); - private static final MethodHandle INPUT_FUNCTION = methodHandle(MapUnionAggregation.class, - "input", - Type.class, - BlockPositionEqual.class, - BlockPositionHashCode.class, - Type.class, - KeyValuePairsState.class, - Block.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(MapUnionAggregation.class, "combine", KeyValuePairsState.class, KeyValuePairsState.class); - - private final BlockTypeOperators blockTypeOperators; - - public MapUnionAggregation(BlockTypeOperators blockTypeOperators) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .comparableTypeParameter("K") - .typeVariable("V") - .returnType(mapType(new TypeSignature("K"), new TypeSignature("V"))) - .argumentType(mapType(new TypeSignature("K"), new TypeSignature("V"))) - .build()) - .description("Aggregate all the maps into a single map") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(mapType(new TypeSignature("K"), new TypeSignature("V"))) - .build()); - this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - MapType outputType = (MapType) boundSignature.getReturnType(); - Type keyType = outputType.getKeyType(); - BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType); - BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType); - - Type valueType = outputType.getValueType(); - - KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(outputType, keyEqual, keyHashCode); - - MethodHandle inputFunction = MethodHandles.insertArguments(INPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType); - inputFunction = normalizeInputMethod(inputFunction, boundSignature, STATE, INPUT_CHANNEL); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION, - ImmutableList.of(new AccumulatorStateDescriptor<>( - KeyValuePairsState.class, - stateSerializer, - new KeyValuePairsStateFactory(keyType, valueType)))); - } + private MapUnionAggregation() {} + @InputFunction + @TypeParameter("K") + @TypeParameter("V") public static void input( - Type keyType, - BlockPositionEqual keyEqual, - BlockPositionHashCode keyHashCode, - Type valueType, - KeyValuePairsState state, - Block value) + @TypeParameter("K") Type keyType, + @OperatorDependency( + operator = OperatorType.EQUAL, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) + BlockPositionEqual keyEqual, + @OperatorDependency( + operator = OperatorType.HASH_CODE, + argumentTypes = "K", + convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) + BlockPositionHashCode keyHashCode, + @TypeParameter("V") Type valueType, + @AggregationState({"K", "V"}) KeyValuePairsState state, + @SqlType("map(K,V)") Block value) { KeyValuePairs pairs = state.get(); if (pairs == null) { @@ -126,12 +73,16 @@ public static void input( state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize); } - public static void combine(KeyValuePairsState state, KeyValuePairsState otherState) + @CombineFunction + public static void combine( + @AggregationState({"K", "V"}) KeyValuePairsState state, + @AggregationState({"K", "V"}) KeyValuePairsState otherState) { MapAggregationFunction.combine(state, otherState); } - public static void output(KeyValuePairsState state, BlockBuilder out) + @OutputFunction("map(K, V)") + public static void output(@AggregationState({"K", "V"}) KeyValuePairsState state, BlockBuilder out) { MapAggregationFunction.output(state, out); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java index dedf892da785..b781a7358fa4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java @@ -13,15 +13,72 @@ */ package io.trino.operator.aggregation; -public class MaxAggregationFunction - extends AbstractMinMaxAggregationFunction +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; + +@AggregationFunction("max") +@Description("Returns the maximum value of the argument") +public final class MaxAggregationFunction { - private static final String NAME = "max"; + private MaxAggregationFunction() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"T", "T"}, + convention = @Convention(arguments = {BLOCK_POSITION, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("T") InOut state, + @BlockPosition @SqlType("T") Block block, + @BlockIndex int position) + throws Throwable + { + if (state.isNull() || ((long) compare.invokeExact(block, position, state)) > 0) { + state.set(block, position); + } + } - public static final MaxAggregationFunction MAX_AGGREGATION = new MaxAggregationFunction(); + @CombineFunction + public static void combine( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"T", "T"}, + convention = @Convention(arguments = {IN_OUT, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("T") InOut state, + @AggregationState("T") InOut otherState) + throws Throwable + { + if (state.isNull() || ((long) compare.invokeExact(otherState, state)) > 0) { + state.set(otherState); + } + } - public MaxAggregationFunction() + @OutputFunction("T") + public static void output(@AggregationState("T") InOut state, BlockBuilder out) { - super(NAME, false, "Returns the maximum value of the argument"); + state.get(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java new file mode 100644 index 000000000000..0c29d7421cfb --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; + +@AggregationFunction("max_by") +@Description("Returns the value of the first argument, associated with the maximum value of the second argument") +public final class MaxByAggregationFunction +{ + private MaxByAggregationFunction() {} + + @InputFunction + @TypeParameter("V") + @TypeParameter("K") + public static void input( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockPosition @SqlType("K") Block keyBlock, + @BlockIndex int position) + throws Throwable + { + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) > 0) { + keyState.set(keyBlock, position); + valueState.set(valueBlock, position); + } + } + + @CombineFunction + public static void combine( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {IN_OUT, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @AggregationState("K") InOut otherKeyState, + @AggregationState("V") InOut otherValueState) + throws Throwable + { + if (otherKeyState.isNull()) { + return; + } + if (keyState.isNull() || ((long) compare.invokeExact(otherKeyState, keyState)) > 0) { + keyState.set(otherKeyState); + valueState.set(otherValueState); + } + } + + @OutputFunction("V") + public static void output( + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + BlockBuilder out) + { + valueState.get(out); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java index 230bacca4d34..8d36e8294c17 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java @@ -13,90 +13,46 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; import io.airlift.stats.QuantileDigest; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; import io.trino.operator.aggregation.state.QuantileDigestState; -import io.trino.operator.aggregation.state.QuantileDigestStateFactory; -import io.trino.operator.aggregation.state.QuantileDigestStateSerializer; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; import io.trino.spi.function.InputFunction; -import io.trino.spi.type.QuantileDigestType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; - -import java.lang.invoke.MethodHandle; -import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import static io.trino.spi.type.StandardTypes.QDIGEST; -import static io.trino.spi.type.TypeSignature.parametricType; import static io.trino.util.MoreMath.nearlyEqual; -import static io.trino.util.Reflection.methodHandle; -@AggregationFunction("merge") +@AggregationFunction(value = "merge", isOrderSensitive = true) +@Description("Merges the input quantile digests into a single quantile digest") public final class MergeQuantileDigestFunction - extends SqlAggregationFunction { - public static final MergeQuantileDigestFunction MERGE = new MergeQuantileDigestFunction(); - public static final String NAME = "merge"; - private static final MethodHandle INPUT_FUNCTION = methodHandle(MergeQuantileDigestFunction.class, "input", Type.class, QuantileDigestState.class, Block.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(MergeQuantileDigestFunction.class, "combine", QuantileDigestState.class, QuantileDigestState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(MergeQuantileDigestFunction.class, "output", QuantileDigestStateSerializer.class, QuantileDigestState.class, BlockBuilder.class); - private static final double COMPARISON_EPSILON = 1E-6; - - public MergeQuantileDigestFunction() - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .comparableTypeParameter("T") - .returnType(parametricType(QDIGEST, new TypeSignature("T"))) - .argumentType(parametricType(QDIGEST, new TypeSignature("T"))) - .build()) - .description("Merges the input quantile digests into a single quantile digest") - .build(), - AggregationFunctionMetadata.builder() - .orderSensitive() - .intermediateType(parametricType(QDIGEST, new TypeSignature("T"))) - .build()); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - QuantileDigestType outputType = (QuantileDigestType) boundSignature.getReturnType(); - Type valueType = outputType.getValueType(); - QuantileDigestStateSerializer stateSerializer = new QuantileDigestStateSerializer(valueType); + private MergeQuantileDigestFunction() {} - return new AggregationMetadata( - INPUT_FUNCTION.bindTo(outputType), - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION.bindTo(stateSerializer), - ImmutableList.of(new AccumulatorStateDescriptor<>( - QuantileDigestState.class, - stateSerializer, - new QuantileDigestStateFactory()))); - } + private static final double COMPARISON_EPSILON = 1.0E-6; @InputFunction - public static void input(Type type, QuantileDigestState state, Block value, int index) + @TypeParameter("V") + public static void input( + @TypeParameter("V") Type type, + @AggregationState QuantileDigestState state, + @BlockPosition @SqlType("V") Block value, + @BlockIndex int index) { merge(state, new QuantileDigest(type.getSlice(value, index))); } @CombineFunction - public static void combine(QuantileDigestState state, QuantileDigestState otherState) + public static void combine(@AggregationState QuantileDigestState state, @AggregationState QuantileDigestState otherState) { merge(state, otherState.getQuantileDigest()); } @@ -122,8 +78,17 @@ private static void merge(QuantileDigestState state, QuantileDigest input) } } - public static void output(QuantileDigestStateSerializer serializer, QuantileDigestState state, BlockBuilder out) + @OutputFunction("qdigest(V)") + public static void output( + @TypeParameter("V") Type type, + @AggregationState QuantileDigestState state, + BlockBuilder out) { - serializer.serialize(state, out); + if (state.getQuantileDigest() == null) { + out.appendNull(); + } + else { + type.writeSlice(out, state.getQuantileDigest().serialize()); + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java index 270f209a39d0..839d56965972 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java @@ -13,15 +13,72 @@ */ package io.trino.operator.aggregation; -public class MinAggregationFunction - extends AbstractMinMaxAggregationFunction +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; + +@AggregationFunction("min") +@Description("Returns the minimum value of the argument") +public final class MinAggregationFunction { - private static final String NAME = "min"; + private MinAggregationFunction() {} + + @InputFunction + @TypeParameter("T") + public static void input( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"T", "T"}, + convention = @Convention(arguments = {BLOCK_POSITION, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("T") InOut state, + @BlockPosition @SqlType("T") Block block, + @BlockIndex int position) + throws Throwable + { + if (state.isNull() || ((long) compare.invokeExact(block, position, state)) < 0) { + state.set(block, position); + } + } - public static final MinAggregationFunction MIN_AGGREGATION = new MinAggregationFunction(); + @CombineFunction + public static void combine( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"T", "T"}, + convention = @Convention(arguments = {IN_OUT, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("T") InOut state, + @AggregationState("T") InOut otherState) + throws Throwable + { + if (state.isNull() || ((long) compare.invokeExact(otherState, state)) < 0) { + state.set(otherState); + } + } - public MinAggregationFunction() + @OutputFunction("T") + public static void output(@AggregationState("T") InOut state, BlockBuilder out) { - super(NAME, true, "Returns the minimum value of the argument"); + state.get(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java new file mode 100644 index 000000000000..b26648b0fa4b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; + +@AggregationFunction("min_by") +@Description("Returns the value of the first argument, associated with the minimum value of the second argument") +public final class MinByAggregationFunction +{ + private MinByAggregationFunction() {} + + @InputFunction + @TypeParameter("V") + @TypeParameter("K") + public static void input( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockPosition @SqlType("K") Block keyBlock, + @BlockIndex int position) + throws Throwable + { + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) < 0) { + keyState.set(keyBlock, position); + valueState.set(valueBlock, position); + } + } + + @CombineFunction + public static void combine( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {IN_OUT, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @AggregationState("K") InOut otherKeyState, + @AggregationState("V") InOut otherValueState) + throws Throwable + { + if (otherKeyState.isNull()) { + return; + } + if (keyState.isNull() || ((long) compare.invokeExact(otherKeyState, keyState)) < 0) { + keyState.set(otherKeyState); + valueState.set(otherValueState); + } + } + + @OutputFunction("V") + public static void output( + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + BlockBuilder out) + { + valueState.get(out); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java index 031441440220..5f7a6f408c1a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java @@ -28,11 +28,11 @@ import io.trino.metadata.SignatureBinder; import io.trino.metadata.SqlAggregationFunction; import io.trino.operator.ParametricImplementationsGroup; +import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.annotations.ImplementationDependency; import io.trino.spi.TrinoException; -import io.trino.spi.function.AccumulatorState; import java.lang.invoke.MethodHandle; import java.util.Collection; @@ -41,11 +41,9 @@ import java.util.StringJoiner; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.operator.ParametricFunctionHelpers.bindDependencies; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; -import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory; -import static io.trino.operator.aggregation.state.StateCompiler.generateStateSerializer; -import static io.trino.operator.aggregation.state.StateCompiler.getSerializedType; import static io.trino.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL; import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static java.lang.String.format; @@ -55,18 +53,18 @@ public class ParametricAggregation extends SqlAggregationFunction { private final ParametricImplementationsGroup implementations; - private final Class stateClass; + private final List> stateDetails; public ParametricAggregation( Signature signature, AggregationHeader details, - Class stateClass, + List> stateDetails, ParametricImplementationsGroup implementations) { super( createFunctionMetadata(signature, details, implementations.getFunctionNullability()), - createAggregationFunctionMetadata(details, stateClass)); - this.stateClass = requireNonNull(stateClass, "stateClass is null"); + createAggregationFunctionMetadata(details, stateDetails)); + this.stateDetails = ImmutableList.copyOf(requireNonNull(stateDetails, "stateDetails is null")); checkArgument(implementations.getFunctionNullability().isReturnNullable(), "currently aggregates are required to be nullable"); this.implementations = requireNonNull(implementations, "implementations is null"); } @@ -99,14 +97,16 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Aggr return functionMetadata.build(); } - private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, Class stateClass) + private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, List> stateDetails) { AggregationFunctionMetadataBuilder builder = AggregationFunctionMetadata.builder(); if (details.isOrderSensitive()) { builder.orderSensitive(); } if (details.isDecomposable()) { - builder.intermediateType(getSerializedType(stateClass).getTypeSignature()); + for (AccumulatorStateDetails stateDetail : stateDetails) { + builder.intermediateType(stateDetail.getSerializedType()); + } } return builder.build(); } @@ -118,6 +118,11 @@ public FunctionDependencyDeclaration getFunctionDependencies() declareDependencies(builder, implementations.getExactImplementations().values()); declareDependencies(builder, implementations.getSpecializedImplementations()); declareDependencies(builder, implementations.getGenericImplementations()); + for (AccumulatorStateDetails stateDetail : stateDetails) { + for (ImplementationDependency dependency : stateDetail.getDependencies()) { + dependency.declareDependencies(builder); + } + } return builder.build(); } @@ -143,11 +148,13 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep AggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature); // Build state factory and serializer - AccumulatorStateDescriptor accumulatorStateDescriptor = generateAccumulatorStateDescriptor(stateClass); - - // Bind provided dependencies to aggregation method handlers FunctionMetadata metadata = getFunctionMetadata(); FunctionBinding functionBinding = SignatureBinder.bindFunction(metadata.getFunctionId(), metadata.getSignature(), boundSignature); + List> accumulatorStateDescriptors = stateDetails.stream() + .map(state -> state.createAccumulatorStateDescriptor(functionBinding, functionDependencies)) + .collect(toImmutableList()); + + // Bind provided dependencies to aggregation method handlers MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), functionBinding, functionDependencies); Optional removeInputHandle = concreteImplementation.getRemoveInputFunction().map( removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), functionBinding, functionDependencies)); @@ -172,20 +179,13 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep removeInputHandle, combineHandle, outputHandle, - ImmutableList.of(accumulatorStateDescriptor)); - } - - private static AccumulatorStateDescriptor generateAccumulatorStateDescriptor(Class stateClass) - { - return new AccumulatorStateDescriptor<>( - stateClass, - generateStateSerializer(stateClass), - generateStateFactory(stateClass)); + accumulatorStateDescriptors); } - public Class getStateClass() + @VisibleForTesting + public List> getStateDetails() { - return stateClass; + return stateDetails; } @VisibleForTesting diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/QuantileDigestAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/QuantileDigestAggregationFunction.java index 4a2b596175e1..49bf0f4c9737 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/QuantileDigestAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/QuantileDigestAggregationFunction.java @@ -13,32 +13,19 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; import io.airlift.stats.QuantileDigest; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; import io.trino.operator.aggregation.state.QuantileDigestState; -import io.trino.operator.aggregation.state.QuantileDigestStateFactory; -import io.trino.operator.aggregation.state.QuantileDigestStateSerializer; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; import io.trino.spi.type.QuantileDigestType; -import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import java.lang.invoke.MethodHandle; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; -import static io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import static io.trino.operator.aggregation.FloatingPointBitsConverterUtil.doubleToSortableLong; import static io.trino.operator.aggregation.FloatingPointBitsConverterUtil.floatToSortableInt; import static io.trino.operator.scalar.QuantileDigestFunctions.DEFAULT_ACCURACY; @@ -47,131 +34,160 @@ import static io.trino.operator.scalar.QuantileDigestFunctions.verifyWeight; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.StandardTypes.QDIGEST; -import static io.trino.spi.type.TypeSignature.parametricType; -import static io.trino.util.Reflection.methodHandle; +import static io.trino.spi.type.RealType.REAL; import static java.lang.Float.intBitsToFloat; -import static java.lang.String.format; -import static java.lang.invoke.MethodHandles.insertArguments; public final class QuantileDigestAggregationFunction - extends SqlAggregationFunction { - public static final QuantileDigestAggregationFunction QDIGEST_AGG = new QuantileDigestAggregationFunction(new TypeSignature("V")); - public static final QuantileDigestAggregationFunction QDIGEST_AGG_WITH_WEIGHT = new QuantileDigestAggregationFunction(new TypeSignature("V"), BIGINT.getTypeSignature()); - public static final QuantileDigestAggregationFunction QDIGEST_AGG_WITH_WEIGHT_AND_ERROR = new QuantileDigestAggregationFunction(new TypeSignature("V"), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature()); - public static final String NAME = "qdigest_agg"; - - private static final MethodHandle INPUT_DOUBLE = methodHandle(QuantileDigestAggregationFunction.class, "inputDouble", QuantileDigestState.class, double.class, long.class, double.class); - private static final MethodHandle INPUT_REAL = methodHandle(QuantileDigestAggregationFunction.class, "inputReal", QuantileDigestState.class, long.class, long.class, double.class); - private static final MethodHandle INPUT_BIGINT = methodHandle(QuantileDigestAggregationFunction.class, "inputBigint", QuantileDigestState.class, long.class, long.class, double.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(QuantileDigestAggregationFunction.class, "combineState", QuantileDigestState.class, QuantileDigestState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(QuantileDigestAggregationFunction.class, "evaluateFinal", QuantileDigestStateSerializer.class, QuantileDigestState.class, BlockBuilder.class); - - private QuantileDigestAggregationFunction(TypeSignature... typeSignatures) + @AggregationFunction(value = "qdigest_agg", isOrderSensitive = true) + @Description("Returns a qdigest from the set of doubles") + public static final class DoubleQuantileDigestAggregationFunction { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .comparableTypeParameter("V") - .returnType(parametricType(QDIGEST, new TypeSignature("V"))) - .argumentTypes(ImmutableList.copyOf(typeSignatures)) - .build()) - .description("Returns a qdigest from the set of reals, bigints or doubles") - .build(), - AggregationFunctionMetadata.builder() - .orderSensitive() - .intermediateType(parametricType(QDIGEST, new TypeSignature("V"))) - .build()); - } + private static final QuantileDigestType OUTPUT_TYPE = new QuantileDigestType(DOUBLE); - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - QuantileDigestType outputType = (QuantileDigestType) boundSignature.getReturnType(); - Type valueType = outputType.getValueType(); - int arity = boundSignature.getArity(); - - QuantileDigestStateSerializer stateSerializer = new QuantileDigestStateSerializer(valueType); - - MethodHandle inputFunction = getMethodHandle(valueType, arity); - inputFunction = normalizeInputMethod(inputFunction, boundSignature, ImmutableList.builder() - .add(STATE) - .addAll(getInputTypes(valueType, arity).stream().map(ignored -> INPUT_CHANNEL).collect(Collectors.toList())) - .build()); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION.bindTo(stateSerializer), - ImmutableList.of(new AccumulatorStateDescriptor<>( - QuantileDigestState.class, - stateSerializer, - new QuantileDigestStateFactory()))); - } + private DoubleQuantileDigestAggregationFunction() {} - private static List getInputTypes(Type valueType, int arity) - { - switch (arity) { - case 1: - // weight and accuracy unspecified - return ImmutableList.of(valueType); - case 2: - // weight specified, accuracy unspecified - return ImmutableList.of(valueType, BIGINT); - case 3: - // weight and accuracy specified - return ImmutableList.of(valueType, BIGINT, DOUBLE); - default: - throw new IllegalArgumentException(format("Unsupported number of arguments: %s", arity)); + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("DOUBLE") double value) + { + input(state, value, DEFAULT_WEIGHT, DEFAULT_ACCURACY); } - } - private static MethodHandle getMethodHandle(Type valueType, int arity) - { - MethodHandle inputFunction; - switch (valueType.getDisplayName()) { - case StandardTypes.DOUBLE: - inputFunction = INPUT_DOUBLE; - break; - case StandardTypes.REAL: - inputFunction = INPUT_REAL; - break; - case StandardTypes.BIGINT: - inputFunction = INPUT_BIGINT; - break; - default: - throw new IllegalArgumentException(format("Unsupported type %s supplied", valueType.getDisplayName())); - } - - switch (arity) { - case 1: - // weight and accuracy unspecified - return insertArguments(inputFunction, 2, DEFAULT_WEIGHT, DEFAULT_ACCURACY); - case 2: - // weight specified, accuracy unspecified - return insertArguments(inputFunction, 3, DEFAULT_ACCURACY); - case 3: - // weight and accuracy specified - return inputFunction; - default: - throw new IllegalArgumentException(format("Unsupported number of arguments: %s", arity)); + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("DOUBLE") double value, + @SqlType("BIGINT") long weight) + { + input(state, value, weight, DEFAULT_ACCURACY); + } + + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("DOUBLE") double value, + @SqlType("BIGINT") long weight, + @SqlType("DOUBLE") double accuracy) + { + internalInput(state, doubleToSortableLong(value), weight, accuracy); + } + + @CombineFunction + public static void combine(@AggregationState QuantileDigestState state, @AggregationState QuantileDigestState otherState) + { + internalCombine(state, otherState); + } + + @OutputFunction("qdigest(DOUBLE)") + public static void output(@AggregationState QuantileDigestState state, BlockBuilder out) + { + internalOutput(OUTPUT_TYPE, state, out); } } - public static void inputDouble(QuantileDigestState state, double value, long weight, double accuracy) + @AggregationFunction(value = "qdigest_agg", isOrderSensitive = true) + @Description("Returns a qdigest from the set of reals") + public static final class RealQuantileDigestAggregationFunction { - inputBigint(state, doubleToSortableLong(value), weight, accuracy); + private static final QuantileDigestType OUTPUT_TYPE = new QuantileDigestType(REAL); + + private RealQuantileDigestAggregationFunction() {} + + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("REAL") long value) + { + input(state, value, DEFAULT_WEIGHT, DEFAULT_ACCURACY); + } + + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("REAL") long value, + @SqlType("BIGINT") long weight) + { + input(state, value, weight, DEFAULT_ACCURACY); + } + + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("REAL") long value, + @SqlType("BIGINT") long weight, + @SqlType("DOUBLE") double accuracy) + { + internalInput(state, floatToSortableInt(intBitsToFloat((int) value)), weight, accuracy); + } + + @CombineFunction + public static void combine(@AggregationState QuantileDigestState state, @AggregationState QuantileDigestState otherState) + { + internalCombine(state, otherState); + } + + @OutputFunction("qdigest(REAL)") + public static void output(@AggregationState QuantileDigestState state, BlockBuilder out) + { + internalOutput(OUTPUT_TYPE, state, out); + } } - public static void inputReal(QuantileDigestState state, long value, long weight, double accuracy) + @AggregationFunction(value = "qdigest_agg", isOrderSensitive = true) + @Description("Returns a qdigest from the set of bigints") + public static final class BigintQuantileDigestAggregationFunction { - inputBigint(state, floatToSortableInt(intBitsToFloat((int) value)), weight, accuracy); + private static final QuantileDigestType OUTPUT_TYPE = new QuantileDigestType(BIGINT); + + private BigintQuantileDigestAggregationFunction() {} + + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("BIGINT") long value) + { + input(state, value, DEFAULT_WEIGHT, DEFAULT_ACCURACY); + } + + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("BIGINT") long value, + @SqlType("BIGINT") long weight) + { + input(state, value, weight, DEFAULT_ACCURACY); + } + + @InputFunction + public static void input( + @AggregationState QuantileDigestState state, + @SqlType("BIGINT") long value, + @SqlType("BIGINT") long weight, + @SqlType("DOUBLE") double accuracy) + { + internalInput(state, value, weight, accuracy); + } + + @CombineFunction + public static void combine(@AggregationState QuantileDigestState state, @AggregationState QuantileDigestState otherState) + { + internalCombine(state, otherState); + } + + @OutputFunction("qdigest(BIGINT)") + public static void output(@AggregationState QuantileDigestState state, BlockBuilder out) + { + internalOutput(OUTPUT_TYPE, state, out); + } } - public static void inputBigint(QuantileDigestState state, long value, long weight, double accuracy) + private static void internalInput( + QuantileDigestState state, + long value, + long weight, + double accuracy) { QuantileDigest qdigest = getOrCreateQuantileDigest(state, verifyAccuracy(accuracy)); state.addMemoryUsage(-qdigest.estimatedInMemorySizeInBytes()); @@ -190,7 +206,7 @@ private static QuantileDigest getOrCreateQuantileDigest(QuantileDigestState stat return qdigest; } - public static void combineState(QuantileDigestState state, QuantileDigestState otherState) + private static void internalCombine(QuantileDigestState state, QuantileDigestState otherState) { QuantileDigest input = otherState.getQuantileDigest(); @@ -206,8 +222,13 @@ public static void combineState(QuantileDigestState state, QuantileDigestState o } } - public static void evaluateFinal(QuantileDigestStateSerializer serializer, QuantileDigestState state, BlockBuilder out) + private static void internalOutput(Type type, QuantileDigestState state, BlockBuilder out) { - serializer.serialize(state, out); + if (state.getQuantileDigest() == null) { + out.appendNull(); + } + else { + type.writeSlice(out, state.getQuantileDigest().serialize()); + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java index 119c3d9b1c8c..3f0474de4f8d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java @@ -13,106 +13,64 @@ */ package io.trino.operator.aggregation; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.aggregation.state.DoubleState; import io.trino.operator.aggregation.state.LongState; -import io.trino.operator.aggregation.state.StateCompiler; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.RemoveInputFunction; +import io.trino.spi.function.SqlType; -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.RealType.REAL; -import static io.trino.util.Reflection.methodHandle; import static java.lang.Float.floatToIntBits; import static java.lang.Float.intBitsToFloat; -public class RealAverageAggregation - extends SqlAggregationFunction +@AggregationFunction("avg") +@Description("Returns the average value of the argument") +public final class RealAverageAggregation { - public static final RealAverageAggregation REAL_AVERAGE_AGGREGATION = new RealAverageAggregation(); - private static final String NAME = "avg"; - - private static final MethodHandle INPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "input", LongState.class, DoubleState.class, long.class); - private static final MethodHandle REMOVE_INPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "removeInput", LongState.class, DoubleState.class, long.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(RealAverageAggregation.class, "combine", LongState.class, DoubleState.class, LongState.class, DoubleState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "output", LongState.class, DoubleState.class, BlockBuilder.class); - - protected RealAverageAggregation() - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .returnType(REAL) - .argumentType(REAL) - .build()) - .description("Returns the average value of the argument") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(BIGINT) - .intermediateType(DOUBLE) - .build()); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Class longStateInterface = LongState.class; - Class doubleStateInterface = DoubleState.class; - AccumulatorStateSerializer longStateSerializer = StateCompiler.generateStateSerializer(longStateInterface); - AccumulatorStateSerializer doubleStateSerializer = StateCompiler.generateStateSerializer(doubleStateInterface); - - MethodHandle inputFunction = normalizeInputMethod(INPUT_FUNCTION, boundSignature, STATE, STATE, INPUT_CHANNEL); - MethodHandle removeFunction = normalizeInputMethod(REMOVE_INPUT_FUNCTION, boundSignature, STATE, STATE, INPUT_CHANNEL); - - return new AggregationMetadata( - inputFunction, - Optional.of(removeFunction), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION, - ImmutableList.of( - new AccumulatorStateDescriptor<>( - longStateInterface, - longStateSerializer, - StateCompiler.generateStateFactory(longStateInterface)), - new AccumulatorStateDescriptor<>( - doubleStateInterface, - doubleStateSerializer, - StateCompiler.generateStateFactory(doubleStateInterface)))); - } + private RealAverageAggregation() {} - public static void input(LongState count, DoubleState sum, long value) + @InputFunction + public static void input( + @AggregationState LongState count, + @AggregationState DoubleState sum, + @SqlType("REAL") long value) { count.setValue(count.getValue() + 1); sum.setValue(sum.getValue() + intBitsToFloat((int) value)); } - public static void removeInput(LongState count, DoubleState sum, long value) + @RemoveInputFunction + public static void removeInput( + @AggregationState LongState count, + @AggregationState DoubleState sum, + @SqlType("REAL") long value) { count.setValue(count.getValue() - 1); sum.setValue(sum.getValue() - intBitsToFloat((int) value)); } - public static void combine(LongState count, DoubleState sum, LongState otherCount, DoubleState otherSum) + @CombineFunction + public static void combine( + @AggregationState LongState count, + @AggregationState DoubleState sum, + @AggregationState LongState otherCount, + @AggregationState DoubleState otherSum) { count.setValue(count.getValue() + otherCount.getValue()); sum.setValue(sum.getValue() + otherSum.getValue()); } - public static void output(LongState count, DoubleState sum, BlockBuilder out) + @OutputFunction("REAL") + public static void output( + @AggregationState LongState count, + @AggregationState DoubleState sum, + BlockBuilder out) { if (count.getValue() == 0) { out.appendNull(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypeSignatureMapping.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/TypeSignatureMapping.java new file mode 100644 index 000000000000..0aa33834f994 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/TypeSignatureMapping.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSortedMap; +import io.trino.operator.annotations.CastImplementationDependency; +import io.trino.operator.annotations.FunctionImplementationDependency; +import io.trino.operator.annotations.ImplementationDependency; +import io.trino.operator.annotations.LiteralImplementationDependency; +import io.trino.operator.annotations.OperatorImplementationDependency; +import io.trino.operator.annotations.TypeImplementationDependency; +import io.trino.spi.type.NamedTypeSignature; +import io.trino.spi.type.ParameterKind; +import io.trino.spi.type.TypeSignature; +import io.trino.spi.type.TypeSignatureParameter; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; + +class TypeSignatureMapping +{ + private final Map mapping; + + public TypeSignatureMapping(Map mapping) + { + this.mapping = ImmutableSortedMap.orderedBy(String.CASE_INSENSITIVE_ORDER) + .putAll(mapping) + .build(); + } + + public Set getTypeParameters() + { + return ImmutableSet.copyOf(mapping.keySet()); + } + + public ImplementationDependency mapTypes(ImplementationDependency dependency) + { + if (mapping.isEmpty()) { + return dependency; + } + if (dependency instanceof TypeImplementationDependency) { + TypeImplementationDependency typeDependency = (TypeImplementationDependency) dependency; + return new TypeImplementationDependency(mapTypeSignature(typeDependency.getSignature())); + } + if (dependency instanceof LiteralImplementationDependency) { + return dependency; + } + if (dependency instanceof FunctionImplementationDependency) { + FunctionImplementationDependency functionDependency = (FunctionImplementationDependency) dependency; + return new FunctionImplementationDependency( + functionDependency.getFullyQualifiedName(), + functionDependency.getArgumentTypes().stream() + .map(this::mapTypeSignature) + .collect(toImmutableList()), + functionDependency.getInvocationConvention(), + functionDependency.getType()); + } + if (dependency instanceof OperatorImplementationDependency) { + OperatorImplementationDependency operatorDependency = (OperatorImplementationDependency) dependency; + return new OperatorImplementationDependency( + operatorDependency.getOperator(), + operatorDependency.getArgumentTypes().stream() + .map(this::mapTypeSignature) + .collect(toImmutableList()), + operatorDependency.getInvocationConvention(), + operatorDependency.getType()); + } + if (dependency instanceof CastImplementationDependency) { + CastImplementationDependency castDependency = (CastImplementationDependency) dependency; + return new CastImplementationDependency( + mapTypeSignature(castDependency.getFromType()), + mapTypeSignature(castDependency.getToType()), + castDependency.getInvocationConvention(), + castDependency.getType()); + } + throw new IllegalArgumentException("Unsupported dependency " + dependency); + } + + public TypeSignature mapTypeSignature(TypeSignature typeSignature) + { + if (mapping.isEmpty()) { + return typeSignature; + } + if (mapping.containsKey(typeSignature.getBase())) { + checkArgument(typeSignature.getParameters().isEmpty(), "Type variable can not have type parameters: %s", typeSignature); + return new TypeSignature(mapping.get(typeSignature.getBase())); + } + return new TypeSignature( + typeSignature.getBase(), + typeSignature.getParameters().stream() + .map(this::mapTypeSignatureParameter) + .collect(toImmutableList())); + } + + private TypeSignatureParameter mapTypeSignatureParameter(TypeSignatureParameter parameter) + { + if (parameter.getKind() == ParameterKind.TYPE) { + return TypeSignatureParameter.typeParameter(mapTypeSignature(parameter.getTypeSignature())); + } + if (parameter.getKind() == ParameterKind.NAMED_TYPE) { + NamedTypeSignature namedTypeSignature = parameter.getNamedTypeSignature(); + return TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature( + namedTypeSignature.getFieldName(), + mapTypeSignature(namedTypeSignature.getTypeSignature()))); + } + return parameter; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java index 1fed573ec31f..a4c4b295f292 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java @@ -13,86 +13,50 @@ */ package io.trino.operator.aggregation.arrayagg; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; +import io.trino.operator.aggregation.NullablePosition; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static io.trino.spi.type.TypeSignature.arrayType; -import static io.trino.util.Reflection.methodHandle; - -public class ArrayAggregationFunction - extends SqlAggregationFunction +@AggregationFunction(value = "array_agg", isOrderSensitive = true) +@Description("return an array of values") +public final class ArrayAggregationFunction { - public static final ArrayAggregationFunction ARRAY_AGG = new ArrayAggregationFunction(); - private static final String NAME = "array_agg"; - private static final MethodHandle INPUT_FUNCTION = methodHandle(ArrayAggregationFunction.class, "input", Type.class, ArrayAggregationState.class, Block.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(ArrayAggregationFunction.class, "combine", Type.class, ArrayAggregationState.class, ArrayAggregationState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ArrayAggregationFunction.class, "output", Type.class, ArrayAggregationState.class, BlockBuilder.class); - - private ArrayAggregationFunction() - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .typeVariable("T") - .returnType(arrayType(new TypeSignature("T"))) - .argumentType(new TypeSignature("T")) - .build()) - .argumentNullability(true) - .description("return an array of values") - .build(), - AggregationFunctionMetadata.builder() - .orderSensitive() - .intermediateType(arrayType(new TypeSignature("T"))) - .build()); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Type type = boundSignature.getArgumentTypes().get(0); - ArrayAggregationStateSerializer stateSerializer = new ArrayAggregationStateSerializer(type); - ArrayAggregationStateFactory stateFactory = new ArrayAggregationStateFactory(type); - - MethodHandle inputFunction = INPUT_FUNCTION.bindTo(type); - MethodHandle combineFunction = COMBINE_FUNCTION.bindTo(type); - MethodHandle outputFunction = OUTPUT_FUNCTION.bindTo(type); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(combineFunction), - outputFunction, - ImmutableList.of(new AccumulatorStateDescriptor<>( - ArrayAggregationState.class, - stateSerializer, - stateFactory))); - } + private ArrayAggregationFunction() {} - public static void input(Type type, ArrayAggregationState state, Block value, int position) + @InputFunction + @TypeParameter("T") + public static void input( + @AggregationState("T") ArrayAggregationState state, + @NullablePosition @BlockPosition @SqlType("T") Block value, + @BlockIndex int position) { state.add(value, position); } - public static void combine(Type type, ArrayAggregationState state, ArrayAggregationState otherState) + @CombineFunction + public static void combine( + @AggregationState("T") ArrayAggregationState state, + @AggregationState("T") ArrayAggregationState otherState) { state.merge(otherState); } - public static void output(Type elementType, ArrayAggregationState state, BlockBuilder out) + @OutputFunction("array(T)") + public static void output( + @TypeParameter("T") Type elementType, + @AggregationState("T") ArrayAggregationState state, + BlockBuilder out) { if (state.isEmpty()) { out.appendNull(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java index 88fa53913346..2b4fce3fa318 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java @@ -15,7 +15,13 @@ import io.trino.spi.block.Block; import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; +@AccumulatorStateMetadata( + stateFactoryClass = ArrayAggregationStateFactory.class, + stateSerializerClass = ArrayAggregationStateSerializer.class, + typeParameters = "T", + serializedType = "ARRAY(T)") public interface ArrayAggregationState extends AccumulatorState { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java index 01bfa9de1be5..c840b57cc69a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java @@ -14,6 +14,7 @@ package io.trino.operator.aggregation.arrayagg; import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; public class ArrayAggregationStateFactory @@ -21,7 +22,7 @@ public class ArrayAggregationStateFactory { private final Type type; - public ArrayAggregationStateFactory(Type type) + public ArrayAggregationStateFactory(@TypeParameter("T") Type type) { this.type = type; } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateSerializer.java index 4bf5b7bd4e17..5b9d93a0c7e2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateSerializer.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; @@ -25,7 +26,7 @@ public class ArrayAggregationStateSerializer private final Type elementType; private final Type arrayType; - public ArrayAggregationStateSerializer(Type elementType) + public ArrayAggregationStateSerializer(@TypeParameter("T") Type elementType) { this.elementType = elementType; this.arrayType = new ArrayType(elementType); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java index f47d5c44ef5e..a9aae24b84ea 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java @@ -37,12 +37,12 @@ public class GroupedHistogramState private TypedHistogram typedHistogram; private long size; - public GroupedHistogramState(Type keyType, BlockPositionEqual equalOperator, BlockPositionHashCode hashCodeOperator, int expectedEntriesCount) + public GroupedHistogramState(Type type, BlockPositionEqual equalOperator, BlockPositionHashCode hashCodeOperator, int expectedEntriesCount) { - this.type = requireNonNull(keyType, "keyType is null"); + this.type = requireNonNull(type, "type is null"); this.equalOperator = requireNonNull(equalOperator, "equalOperator is null"); this.hashCodeOperator = requireNonNull(hashCodeOperator, "hashCodeOperator is null"); - typedHistogram = new GroupedTypedHistogram(keyType, equalOperator, hashCodeOperator, expectedEntriesCount); + typedHistogram = new GroupedTypedHistogram(type, equalOperator, hashCodeOperator, expectedEntriesCount); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java index 30f65f12db57..00e638cf2532 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java @@ -13,82 +13,35 @@ */ package io.trino.operator.aggregation.histogram; -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.type.BlockTypeOperators; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; -import io.trino.type.BlockTypeOperators.BlockPositionHashCode; -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.TypeSignature.mapType; -import static io.trino.util.Reflection.methodHandle; import static java.util.Objects.requireNonNull; -public class Histogram - extends SqlAggregationFunction +@AggregationFunction("histogram") +@Description("Count the number of times each value occurs") +public final class Histogram { - public static final String NAME = "histogram"; - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(Histogram.class, "output", Type.class, HistogramState.class, BlockBuilder.class); - private static final MethodHandle INPUT_FUNCTION = methodHandle(Histogram.class, "input", Type.class, HistogramState.class, Block.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(Histogram.class, "combine", HistogramState.class, HistogramState.class); - - public static final int EXPECTED_SIZE_FOR_HASHING = 10; - private final BlockTypeOperators blockTypeOperators; - - public Histogram(BlockTypeOperators blockTypeOperators) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .comparableTypeParameter("K") - .returnType(mapType(new TypeSignature("K"), BIGINT.getTypeSignature())) - .argumentType(new TypeSignature("K")) - .build()) - .description("Count the number of times each value occurs") - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(mapType(new TypeSignature("K"), BIGINT.getTypeSignature())) - .build()); - this.blockTypeOperators = blockTypeOperators; - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Type keyType = boundSignature.getArgumentTypes().get(0); - BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType); - BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType); - Type outputType = boundSignature.getReturnType(); - HistogramStateSerializer stateSerializer = new HistogramStateSerializer(outputType); - MethodHandle inputFunction = INPUT_FUNCTION.bindTo(keyType); - MethodHandle outputFunction = OUTPUT_FUNCTION.bindTo(outputType); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - outputFunction, - ImmutableList.of(new AccumulatorStateDescriptor<>( - HistogramState.class, - stateSerializer, - new HistogramStateFactory(keyType, keyEqual, keyHashCode, EXPECTED_SIZE_FOR_HASHING)))); - } + private Histogram() {} - public static void input(Type type, HistogramState state, Block key, int position) + @InputFunction + @TypeParameter("T") + public static void input( + @TypeParameter("T") Type type, + @AggregationState("T") HistogramState state, + @BlockPosition @SqlType("T") Block key, + @BlockIndex int position) { TypedHistogram typedHistogram = state.get(); long startSize = typedHistogram.getEstimatedSize(); @@ -96,7 +49,8 @@ public static void input(Type type, HistogramState state, Block key, int positio state.addMemoryUsage(typedHistogram.getEstimatedSize() - startSize); } - public static void combine(HistogramState state, HistogramState otherState) + @CombineFunction + public static void combine(@AggregationState("T") HistogramState state, @AggregationState("T") HistogramState otherState) { // NOTE: state = current merged state; otherState = scratchState (new data to be added) // for grouped histograms and single histograms, we have a single histogram object. In neither case, can otherState.get() return null. @@ -110,7 +64,8 @@ public static void combine(HistogramState state, HistogramState otherState) state.addMemoryUsage(typedHistogram.getEstimatedSize() - startSize); } - public static void output(Type type, HistogramState state, BlockBuilder out) + @OutputFunction("map(T, BIGINT)") + public static void output(@TypeParameter("T") Type type, @AggregationState("T") HistogramState state, BlockBuilder out) { TypedHistogram typedHistogram = state.get(); typedHistogram.serialize(out); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java index 2d9f451a0e6b..88d882b143a2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java @@ -17,7 +17,11 @@ import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; -@AccumulatorStateMetadata(stateFactoryClass = HistogramStateFactory.class, stateSerializerClass = HistogramStateSerializer.class) +@AccumulatorStateMetadata( + stateFactoryClass = HistogramStateFactory.class, + stateSerializerClass = HistogramStateSerializer.class, + typeParameters = "T", + serializedType = "map(T, BIGINT)") public interface HistogramState extends AccumulatorState { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java index 42751413b00f..b54ad9106d83 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java @@ -14,41 +14,55 @@ package io.trino.operator.aggregation.histogram; import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.Convention; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static java.util.Objects.requireNonNull; public class HistogramStateFactory implements AccumulatorStateFactory { - private final Type keyType; + public static final int EXPECTED_SIZE_FOR_HASHING = 10; + + private final Type type; private final BlockPositionEqual equalOperator; private final BlockPositionHashCode hashCodeOperator; - private final int expectedEntriesCount; public HistogramStateFactory( - Type keyType, - BlockPositionEqual equalOperator, - BlockPositionHashCode hashCodeOperator, - int expectedEntriesCount) + @TypeParameter("T") Type type, + @OperatorDependency( + operator = OperatorType.EQUAL, + argumentTypes = {"T", "T"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) + BlockPositionEqual equalOperator, + @OperatorDependency( + operator = OperatorType.HASH_CODE, + argumentTypes = "T", + convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) + BlockPositionHashCode hashCodeOperator) { - this.keyType = requireNonNull(keyType, "keyType is null"); + this.type = requireNonNull(type, "type is null"); this.equalOperator = requireNonNull(equalOperator, "equalOperator is null"); this.hashCodeOperator = requireNonNull(hashCodeOperator, "hashCodeOperator is null"); - this.expectedEntriesCount = expectedEntriesCount; } @Override public HistogramState createSingleState() { - return new SingleHistogramState(keyType, equalOperator, hashCodeOperator, expectedEntriesCount); + return new SingleHistogramState(type, equalOperator, hashCodeOperator, EXPECTED_SIZE_FOR_HASHING); } @Override public HistogramState createGroupedState() { - return new GroupedHistogramState(keyType, equalOperator, hashCodeOperator, expectedEntriesCount); + return new GroupedHistogramState(type, equalOperator, hashCodeOperator, EXPECTED_SIZE_FOR_HASHING); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateSerializer.java index d0489c0bc9b8..d59781e8e805 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateSerializer.java @@ -16,16 +16,17 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import static io.trino.operator.aggregation.histogram.Histogram.EXPECTED_SIZE_FOR_HASHING; +import static io.trino.operator.aggregation.histogram.HistogramStateFactory.EXPECTED_SIZE_FOR_HASHING; public class HistogramStateSerializer implements AccumulatorStateSerializer { private final Type serializedType; - public HistogramStateSerializer(Type serializedType) + public HistogramStateSerializer(@TypeParameter("map(T, BIGINT)") Type serializedType) { this.serializedType = serializedType; } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java index dcb57481ebb0..b1511bff3505 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java @@ -18,7 +18,8 @@ import io.trino.operator.aggregation.AbstractGroupCollectionAggregationState; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; -import io.trino.spi.type.Type; + +import static io.trino.spi.type.VarcharType.VARCHAR; public final class GroupListaggAggregationState extends AbstractGroupCollectionAggregationState @@ -32,9 +33,9 @@ public final class GroupListaggAggregationState private Slice overflowFiller; private boolean showOverflowEntryCount; - public GroupListaggAggregationState(Type valueType) + public GroupListaggAggregationState() { - super(PageBuilder.withMaxPageSize(MAX_BLOCK_SIZE, ImmutableList.of(valueType))); + super(PageBuilder.withMaxPageSize(MAX_BLOCK_SIZE, ImmutableList.of(VARCHAR))); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java index a071fe973d93..8355dd9b5dda 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java @@ -14,114 +14,44 @@ package io.trino.operator.aggregation.listagg; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorStateFactory; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.StandardTypes; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; - -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; + import static io.trino.spi.StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.TypeSignature.arrayType; -import static io.trino.spi.type.TypeSignatureParameter.typeVariable; -import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.util.Reflection.methodHandle; import static java.lang.String.format; -public class ListaggAggregationFunction - extends SqlAggregationFunction +@AggregationFunction(value = "listagg", isOrderSensitive = true) +@Description("concatenates the input values with the specified separator") +public final class ListaggAggregationFunction { - public static final ListaggAggregationFunction LISTAGG = new ListaggAggregationFunction(); - public static final String NAME = "listagg"; - private static final MethodHandle INPUT_FUNCTION = methodHandle(ListaggAggregationFunction.class, "input", Type.class, ListaggAggregationState.class, Block.class, Slice.class, boolean.class, Slice.class, boolean.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(ListaggAggregationFunction.class, "combine", Type.class, ListaggAggregationState.class, ListaggAggregationState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ListaggAggregationFunction.class, "output", Type.class, ListaggAggregationState.class, BlockBuilder.class); - private static final int MAX_OUTPUT_LENGTH = DEFAULT_MAX_PAGE_SIZE_IN_BYTES; private static final int MAX_OVERFLOW_FILLER_LENGTH = 65_536; - private ListaggAggregationFunction() - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .returnType(VARCHAR) - .argumentType(new TypeSignature(StandardTypes.VARCHAR, typeVariable("v"))) - .argumentType(new TypeSignature(StandardTypes.VARCHAR, typeVariable("d"))) - .argumentType(BOOLEAN) - .argumentType(new TypeSignature(StandardTypes.VARCHAR, typeVariable("f"))) - .argumentType(BOOLEAN) - .build()) - .nullable() - .argumentNullability(true, false, false, false, false) - .description("concatenates the input values with the specified separator") - .build(), - AggregationFunctionMetadata.builder() - .orderSensitive() - .intermediateType(VARCHAR.getTypeSignature()) - .intermediateType(BOOLEAN.getTypeSignature()) - .intermediateType(VARCHAR.getTypeSignature()) - .intermediateType(BOOLEAN.getTypeSignature()) - .intermediateType(arrayType(VARCHAR.getTypeSignature())) - .build()); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Type type = VARCHAR; - AccumulatorStateSerializer stateSerializer = new ListaggAggregationStateSerializer(type); - AccumulatorStateFactory stateFactory = new ListaggAggregationStateFactory(type); - - MethodHandle inputFunction = normalizeInputMethod( - INPUT_FUNCTION.bindTo(type), - boundSignature, - STATE, - NULLABLE_BLOCK_INPUT_CHANNEL, - INPUT_CHANNEL, - INPUT_CHANNEL, - INPUT_CHANNEL, - INPUT_CHANNEL, - BLOCK_INDEX); - MethodHandle combineFunction = COMBINE_FUNCTION.bindTo(type); - MethodHandle outputFunction = OUTPUT_FUNCTION.bindTo(type); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(combineFunction), - outputFunction, - ImmutableList.of(new AccumulatorStateDescriptor<>( - ListaggAggregationState.class, - stateSerializer, - stateFactory))); - } - - public static void input(Type type, ListaggAggregationState state, Block value, Slice separator, boolean overflowError, Slice overflowFiller, boolean showOverflowEntryCount, int position) + private ListaggAggregationFunction() {} + + @InputFunction + public static void input( + @AggregationState ListaggAggregationState state, + @BlockPosition @SqlType("VARCHAR") Block value, + @SqlType("VARCHAR") Slice separator, + @SqlType("BOOLEAN") boolean overflowError, + @SqlType("VARCHAR") Slice overflowFiller, + @SqlType("BOOLEAN") boolean showOverflowEntryCount, + @BlockIndex int position) { if (state.isEmpty()) { if (overflowFiller.length() > MAX_OVERFLOW_FILLER_LENGTH) { @@ -134,12 +64,11 @@ public static void input(Type type, ListaggAggregationState state, Block value, state.setOverflowFiller(overflowFiller); state.setShowOverflowEntryCount(showOverflowEntryCount); } - if (!value.isNull(position)) { - state.add(value, position); - } + state.add(value, position); } - public static void combine(Type type, ListaggAggregationState state, ListaggAggregationState otherState) + @CombineFunction + public static void combine(@AggregationState ListaggAggregationState state, @AggregationState ListaggAggregationState otherState) { Slice previousSeparator = state.getSeparator(); if (previousSeparator == null) { @@ -152,7 +81,8 @@ public static void combine(Type type, ListaggAggregationState state, ListaggAggr state.merge(otherState); } - public static void output(Type type, ListaggAggregationState state, BlockBuilder out) + @OutputFunction("VARCHAR") + public static void output(ListaggAggregationState state, BlockBuilder out) { if (state.isEmpty()) { out.appendNull(); @@ -163,7 +93,7 @@ public static void output(Type type, ListaggAggregationState state, BlockBuilder } @VisibleForTesting - protected static void outputState(ListaggAggregationState state, BlockBuilder out, int maxOutputLength) + public static void outputState(ListaggAggregationState state, BlockBuilder out, int maxOutputLength) { Slice separator = state.getSeparator(); int separatorLength = separator.length(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java index f7417e1a2ea7..e472d31c29eb 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java @@ -16,7 +16,11 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; +@AccumulatorStateMetadata( + stateFactoryClass = ListaggAggregationStateFactory.class, + stateSerializerClass = ListaggAggregationStateSerializer.class) public interface ListaggAggregationState extends AccumulatorState { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateFactory.java index 477c81766043..1c3f54fcb3a1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateFactory.java @@ -14,27 +14,19 @@ package io.trino.operator.aggregation.listagg; import io.trino.spi.function.AccumulatorStateFactory; -import io.trino.spi.type.Type; public class ListaggAggregationStateFactory implements AccumulatorStateFactory { - private final Type type; - - public ListaggAggregationStateFactory(Type type) - { - this.type = type; - } - @Override public ListaggAggregationState createSingleState() { - return new SingleListaggAggregationState(type); + return new SingleListaggAggregationState(); } @Override public ListaggAggregationState createGroupedState() { - return new GroupListaggAggregationState(type); + return new GroupListaggAggregationState(); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateSerializer.java index 1b7126fd7860..965fc2dcf05c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationStateSerializer.java @@ -32,14 +32,12 @@ public class ListaggAggregationStateSerializer implements AccumulatorStateSerializer { - private final Type elementType; private final Type arrayType; private final Type serializedType; - public ListaggAggregationStateSerializer(Type elementType) + public ListaggAggregationStateSerializer() { - this.elementType = elementType; - this.arrayType = new ArrayType(elementType); + this.arrayType = new ArrayType(VARCHAR); this.serializedType = RowType.anonymous(ImmutableList.of(VARCHAR, BOOLEAN, VARCHAR, BOOLEAN, arrayType)); } @@ -64,7 +62,7 @@ public void serialize(ListaggAggregationState state, BlockBuilder out) BlockBuilder stateElementsBlockBuilder = rowBlockBuilder.beginBlockEntry(); state.forEach((block, position) -> { - elementType.appendTo(block, position, stateElementsBlockBuilder); + VARCHAR.appendTo(block, position, stateElementsBlockBuilder); return true; }); rowBlockBuilder.closeEntry(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/SingleListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/SingleListaggAggregationState.java index 1b2e94672124..9f0125fd74ca 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/SingleListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/SingleListaggAggregationState.java @@ -16,11 +16,10 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.Type; import org.openjdk.jol.info.ClassLayout; import static com.google.common.base.Verify.verify; -import static java.util.Objects.requireNonNull; +import static io.trino.spi.type.VarcharType.VARCHAR; public class SingleListaggAggregationState implements ListaggAggregationState @@ -31,12 +30,6 @@ public class SingleListaggAggregationState private boolean overflowError; private Slice overflowFiller; private boolean showOverflowEntryCount; - private final Type type; - - public SingleListaggAggregationState(Type type) - { - this.type = requireNonNull(type, "type is null"); - } @Override public long getEstimatedSize() @@ -100,9 +93,9 @@ public boolean showOverflowEntryCount() public void add(Block block, int position) { if (blockBuilder == null) { - blockBuilder = type.createBlockBuilder(null, 16); + blockBuilder = VARCHAR.createBlockBuilder(null, 16); } - type.appendTo(block, position, blockBuilder); + VARCHAR.appendTo(block, position, blockBuilder); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java deleted file mode 100644 index 3d84ef00f1a2..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.minmaxby; - -import com.google.common.collect.ImmutableList; -import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.state.BlockPositionState; -import io.trino.operator.aggregation.state.BlockPositionStateSerializer; -import io.trino.operator.aggregation.state.NullableBooleanState; -import io.trino.operator.aggregation.state.NullableBooleanStateSerializer; -import io.trino.operator.aggregation.state.NullableDoubleState; -import io.trino.operator.aggregation.state.NullableDoubleStateSerializer; -import io.trino.operator.aggregation.state.NullableLongState; -import io.trino.operator.aggregation.state.NullableLongStateSerializer; -import io.trino.operator.aggregation.state.NullableState; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorState; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.util.MinMaxCompare; - -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.util.Optional; - -import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static io.trino.util.MinMaxCompare.getMinMaxCompareFunctionDependencies; -import static io.trino.util.MinMaxCompare.getMinMaxCompareOperatorType; -import static java.lang.invoke.MethodHandles.explicitCastArguments; -import static java.lang.invoke.MethodHandles.insertArguments; -import static java.lang.invoke.MethodHandles.lookup; -import static java.lang.invoke.MethodType.methodType; - -public abstract class AbstractMinMaxBy - extends SqlAggregationFunction -{ - private final boolean min; - - protected AbstractMinMaxBy(boolean min, String description) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name((min ? "min" : "max") + "_by") - .orderableTypeParameter("K") - .typeVariable("V") - .returnType(new TypeSignature("V")) - .argumentType(new TypeSignature("V")) - .argumentType(new TypeSignature("K")) - .build()) - .argumentNullability(true, false) - .description(description) - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(new TypeSignature("K")) - .intermediateType(new TypeSignature("V")) - .build()); - this.min = min; - } - - @Override - public FunctionDependencyDeclaration getFunctionDependencies() - { - return getMinMaxCompareFunctionDependencies(new TypeSignature("K"), min); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) - { - try { - Type keyType = boundSignature.getArgumentType(1); - Type valueType = boundSignature.getArgumentType(0); - - MethodHandle inputMethod = generateInput(keyType, valueType, functionDependencies); - MethodHandle combineMethod = generateCombine(keyType, valueType, functionDependencies); - MethodHandle outputMethod = generateOutput(keyType, valueType); - - return new AggregationMetadata( - inputMethod, - Optional.empty(), - Optional.of(combineMethod), - outputMethod, - ImmutableList.of( - getAccumulatorStateDescriptor(keyType), - getAccumulatorStateDescriptor(valueType))); - } - catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } - } - - private static AccumulatorStateDescriptor getAccumulatorStateDescriptor(Type type) - { - Class stateClass = getStateClass(type); - if (stateClass.equals(BlockPositionState.class)) { - return new AccumulatorStateDescriptor<>( - BlockPositionState.class, - new BlockPositionStateSerializer(type), - generateStateFactory(BlockPositionState.class)); - } - return getAccumulatorStateDescriptor(stateClass, type); - } - - private static AccumulatorStateDescriptor getAccumulatorStateDescriptor(Class stateClass, Type type) - { - return new AccumulatorStateDescriptor<>( - stateClass, - getStateSerializer(stateClass, type), - generateStateFactory(stateClass)); - } - - private MethodHandle generateInput(Type keyType, Type valueType, FunctionDependencies functionDependencies) - throws ReflectiveOperationException - { - MethodHandle input = lookup().findStatic( - AbstractMinMaxBy.class, - "input", - methodType(void.class, - MethodHandle.class, - MethodHandle.class, - MethodHandle.class, - NullableState.class, - NullableState.class, - Block.class, - Block.class, - int.class)); - - Class keyState = getStateClass(keyType); - Class valueState = getStateClass(valueType); - - MethodHandle compareStateBlockPosition = generateCompareStateBlockPosition(keyType, functionDependencies, keyState); - MethodHandle setKeyState = getSetStateValue(keyType, keyState); - MethodHandle setValueState = getSetStateValue(valueType, valueState); - input = insertArguments(input, 0, compareStateBlockPosition, setKeyState, setValueState); - return explicitCastArguments(input, methodType(void.class, keyState, valueState, Block.class, Block.class, int.class)); - } - - private static void input( - MethodHandle compareStateBlockPosition, - MethodHandle setKeyState, - MethodHandle setValueState, - NullableState keyState, - NullableState valueState, - Block value, - Block key, - int position) - throws Throwable - { - if (keyState.isNull() || (boolean) compareStateBlockPosition.invoke(key, position, keyState)) { - setKeyState.invoke(keyState, key, position); - setValueState.invoke(valueState, value, position); - } - } - - private MethodHandle generateCombine(Type keyType, Type valueType, FunctionDependencies functionDependencies) - throws ReflectiveOperationException - { - MethodHandle combine = lookup().findStatic( - AbstractMinMaxBy.class, - "combine", - methodType(void.class, - MethodHandle.class, - MethodHandle.class, - MethodHandle.class, - NullableState.class, - NullableState.class, - NullableState.class, - NullableState.class)); - - Class keyState = getStateClass(keyType); - Class valueState = getStateClass(valueType); - - MethodHandle compareStateBlockPosition = generateCompareStateState(keyType, functionDependencies, keyState); - MethodHandle setKeyState = lookup().findVirtual(keyState, "set", methodType(void.class, keyState)); - MethodHandle setValueState = lookup().findVirtual(valueState, "set", methodType(void.class, valueState)); - combine = insertArguments(combine, 0, compareStateBlockPosition, setKeyState, setValueState); - return explicitCastArguments(combine, methodType(void.class, keyState, valueState, keyState, valueState)); - } - - private static void combine( - MethodHandle compareStateState, - MethodHandle setKeyState, - MethodHandle setValueState, - NullableState keyState, - NullableState valueState, - NullableState otherKeyState, - NullableState otherValueState) - throws Throwable - { - if (otherKeyState.isNull()) { - return; - } - if (keyState.isNull() || (boolean) compareStateState.invoke(otherKeyState, keyState)) { - setKeyState.invoke(keyState, otherKeyState); - setValueState.invoke(valueState, otherValueState); - } - } - - private static MethodHandle generateOutput(Type keyType, Type valueType) - throws ReflectiveOperationException - { - Class keyState = getStateClass(keyType); - Class valueState = getStateClass(valueType); - MethodHandle writeState = lookup().findStatic(AbstractMinMaxBy.class, "writeState", methodType(void.class, Type.class, valueState, BlockBuilder.class)) - .bindTo(valueType); - MethodHandle output = lookup().findStatic( - AbstractMinMaxBy.class, - "output", - methodType(void.class, MethodHandle.class, NullableState.class, NullableState.class, BlockBuilder.class)); - output = output.bindTo(writeState); - return explicitCastArguments(output, methodType(void.class, keyState, valueState, BlockBuilder.class)); - } - - private static void output( - MethodHandle valueWriter, - NullableState keyState, - NullableState valueState, - BlockBuilder blockBuilder) - throws Throwable - { - if (keyState.isNull() || valueState.isNull()) { - blockBuilder.appendNull(); - return; - } - valueWriter.invoke(valueState, blockBuilder); - } - - @UsedByGeneratedCode - private static void writeState(Type type, NullableLongState state, BlockBuilder output) - { - type.writeLong(output, state.getValue()); - } - - @UsedByGeneratedCode - private static void writeState(Type type, NullableDoubleState state, BlockBuilder output) - { - type.writeDouble(output, state.getValue()); - } - - @UsedByGeneratedCode - private static void writeState(Type type, NullableBooleanState state, BlockBuilder output) - { - type.writeBoolean(output, state.getValue()); - } - - @UsedByGeneratedCode - private static void writeState(Type type, BlockPositionState state, BlockBuilder output) - { - type.appendTo(state.getBlock(), state.getPosition(), output); - } - - private static Class getStateClass(Type type) - { - if (type.getJavaType().equals(long.class)) { - return NullableLongState.class; - } - if (type.getJavaType().equals(double.class)) { - return NullableDoubleState.class; - } - if (type.getJavaType().equals(boolean.class)) { - return NullableBooleanState.class; - } - return BlockPositionState.class; - } - - @SuppressWarnings("unchecked") - private static AccumulatorStateSerializer getStateSerializer(Class state, Type type) - { - if (NullableLongState.class.equals(state)) { - return (AccumulatorStateSerializer) new NullableLongStateSerializer(type); - } - if (NullableDoubleState.class.equals(state)) { - return (AccumulatorStateSerializer) new NullableDoubleStateSerializer(type); - } - if (NullableBooleanState.class.equals(state)) { - return (AccumulatorStateSerializer) new NullableBooleanStateSerializer(type); - } - if (BlockPositionState.class.equals(state)) { - return (AccumulatorStateSerializer) new BlockPositionStateSerializer(type); - } - throw new IllegalArgumentException("Unsupported state class: " + state); - } - - private static MethodHandle getSetStateValue(Type type, Class stateClass) - throws ReflectiveOperationException - { - if (stateClass.equals(BlockPositionState.class)) { - return lookup().findStatic(AbstractMinMaxBy.class, "setStateValue", methodType(void.class, BlockPositionState.class, Block.class, int.class)); - } - return lookup().findStatic(AbstractMinMaxBy.class, "setStateValue", methodType(void.class, Type.class, stateClass, Block.class, int.class)) - .bindTo(type); - } - - @UsedByGeneratedCode - private static void setStateValue(BlockPositionState state, Block block, int position) - { - state.setBlock(block); - state.setPosition(position); - } - - @UsedByGeneratedCode - private static void setStateValue(Type valueType, NullableLongState state, Block block, int position) - { - if (block.isNull(position)) { - state.setNull(true); - } - else { - state.setNull(false); - state.setValue(valueType.getLong(block, position)); - } - } - - @UsedByGeneratedCode - private static void setStateValue(Type valueType, NullableDoubleState state, Block block, int position) - { - if (block.isNull(position)) { - state.setNull(true); - } - else { - state.setNull(false); - state.setValue(valueType.getDouble(block, position)); - } - } - - @UsedByGeneratedCode - private static void setStateValue(Type valueType, NullableBooleanState state, Block block, int position) - { - if (block.isNull(position)) { - state.setNull(true); - } - else { - state.setNull(false); - state.setValue(valueType.getBoolean(block, position)); - } - } - - private MethodHandle generateCompareStateBlockPosition(Type type, FunctionDependencies functionDependencies, Class state) - throws ReflectiveOperationException - { - if (state.equals(BlockPositionState.class)) { - MethodHandle comparisonMethod = lookup().findStatic(AbstractMinMaxBy.class, "compareStateBlockPosition", methodType(long.class, MethodHandle.class, Block.class, int.class, BlockPositionState.class)) - .bindTo(functionDependencies.getOperatorInvoker(getMinMaxCompareOperatorType(min), ImmutableList.of(type, type), simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION)).getMethodHandle()); - return MinMaxCompare.comparisonToMinMaxResult(min, comparisonMethod); - } - MethodHandle minMaxMethod = MinMaxCompare.getMinMaxCompare(functionDependencies, type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, NEVER_NULL), min); - MethodHandle stateGetValue = lookup().findVirtual(state, "getValue", methodType(type.getJavaType())); - return MethodHandles.filterArguments(minMaxMethod, 2, stateGetValue); - } - - private static long compareStateBlockPosition(MethodHandle blockPositionBlockPositionOperator, Block left, int leftPosition, BlockPositionState state) - throws Throwable - { - return (long) blockPositionBlockPositionOperator.invokeExact(left, leftPosition, state.getBlock(), state.getPosition()); - } - - private MethodHandle generateCompareStateState(Type type, FunctionDependencies functionDependencies, Class state) - throws ReflectiveOperationException - { - if (state.equals(BlockPositionState.class)) { - MethodHandle comparisonMethod = lookup().findStatic(AbstractMinMaxBy.class, "compareStateState", methodType(long.class, MethodHandle.class, BlockPositionState.class, BlockPositionState.class)) - .bindTo(functionDependencies.getOperatorInvoker(getMinMaxCompareOperatorType(min), ImmutableList.of(type, type), simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION)).getMethodHandle()); - return MinMaxCompare.comparisonToMinMaxResult(min, comparisonMethod); - } - MethodHandle maxMaxMethod = MinMaxCompare.getMinMaxCompare(functionDependencies, type, simpleConvention(FAIL_ON_NULL, NEVER_NULL, NEVER_NULL), min); - MethodHandle stateGetValue = lookup().findVirtual(state, "getValue", methodType(type.getJavaType())); - return MethodHandles.filterArguments(maxMaxMethod, 0, stateGetValue, stateGetValue); - } - - private static long compareStateState(MethodHandle blockPositionBlockPositionOperator, BlockPositionState state, BlockPositionState otherState) - throws Throwable - { - return (long) blockPositionBlockPositionOperator.invokeExact(state.getBlock(), state.getPosition(), otherState.getBlock(), otherState.getPosition()); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java deleted file mode 100644 index 3eb90ca584cb..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.minmaxby; - -import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.TypedKeyValueHeap; -import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.util.MinMaxCompare; - -import java.lang.invoke.MethodHandle; -import java.util.Optional; - -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; -import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.TypeSignature.arrayType; -import static io.trino.util.Failures.checkCondition; -import static io.trino.util.MinMaxCompare.getMinMaxCompare; -import static io.trino.util.Reflection.methodHandle; -import static java.lang.Math.toIntExact; - -public abstract class AbstractMinMaxByNAggregationFunction - extends SqlAggregationFunction -{ - private static final MethodHandle INPUT_FUNCTION = methodHandle(AbstractMinMaxByNAggregationFunction.class, "input", MethodHandle.class, Type.class, Type.class, MinMaxByNState.class, Block.class, Block.class, long.class, int.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle(AbstractMinMaxByNAggregationFunction.class, "combine", MinMaxByNState.class, MinMaxByNState.class); - private static final MethodHandle OUTPUT_FUNCTION = methodHandle(AbstractMinMaxByNAggregationFunction.class, "output", ArrayType.class, MinMaxByNState.class, BlockBuilder.class); - private static final long MAX_NUMBER_OF_VALUES = 10_000; - - private final boolean min; - - protected AbstractMinMaxByNAggregationFunction(String name, boolean min, String description) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(name) - .typeVariable("V") - .orderableTypeParameter("K") - .returnType(arrayType(new TypeSignature("V"))) - .argumentType(new TypeSignature("V")) - .argumentType(new TypeSignature("K")) - .argumentType(BIGINT) - .build()) - .argumentNullability(true, false, false) - .description(description) - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(BIGINT.getTypeSignature()) - .intermediateType(arrayType(new TypeSignature("K"))) - .intermediateType(arrayType(new TypeSignature("V"))) - .build()); - this.min = min; - } - - @Override - public FunctionDependencyDeclaration getFunctionDependencies(BoundSignature boundSignature) - { - return MinMaxCompare.getMinMaxCompareFunctionDependencies(boundSignature.getArgumentTypes().get(1).getTypeSignature(), min); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) - { - Type keyType = boundSignature.getArgumentTypes().get(1); - Type valueType = boundSignature.getArgumentTypes().get(0); - MethodHandle keyComparisonMethod = getMinMaxCompare(functionDependencies, keyType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), min); - - MinMaxByNStateSerializer stateSerializer = new MinMaxByNStateSerializer(keyComparisonMethod, keyType, valueType); - ArrayType outputType = new ArrayType(valueType); - - MethodHandle inputFunction = INPUT_FUNCTION.bindTo(keyComparisonMethod).bindTo(valueType).bindTo(keyType); - inputFunction = normalizeInputMethod(inputFunction, boundSignature, STATE, NULLABLE_BLOCK_INPUT_CHANNEL, BLOCK_INPUT_CHANNEL, INPUT_CHANNEL, BLOCK_INDEX); - - return new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - OUTPUT_FUNCTION.bindTo(outputType), - ImmutableList.of(new AccumulatorStateDescriptor<>( - MinMaxByNState.class, - stateSerializer, - new MinMaxByNStateFactory()))); - } - - public static void input(MethodHandle keyComparisonMethod, Type valueType, Type keyType, MinMaxByNState state, Block value, Block key, long n, int blockIndex) - { - TypedKeyValueHeap heap = state.getTypedKeyValueHeap(); - if (heap == null) { - if (n <= 0) { - throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "third argument of max_by/min_by must be a positive integer"); - } - checkCondition(n <= MAX_NUMBER_OF_VALUES, INVALID_FUNCTION_ARGUMENT, "third argument of max_by/min_by must be less than or equal to %s; found %s", MAX_NUMBER_OF_VALUES, n); - heap = new TypedKeyValueHeap(keyComparisonMethod, keyType, valueType, toIntExact(n)); - state.setTypedKeyValueHeap(heap); - } - - long startSize = heap.getEstimatedSize(); - if (!key.isNull(blockIndex)) { - heap.add(key, value, blockIndex); - } - state.addMemoryUsage(heap.getEstimatedSize() - startSize); - } - - public static void combine(MinMaxByNState state, MinMaxByNState otherState) - { - TypedKeyValueHeap otherHeap = otherState.getTypedKeyValueHeap(); - if (otherHeap == null) { - return; - } - TypedKeyValueHeap heap = state.getTypedKeyValueHeap(); - if (heap == null) { - state.setTypedKeyValueHeap(otherHeap); - return; - } - long startSize = heap.getEstimatedSize(); - heap.addAll(otherHeap); - state.addMemoryUsage(heap.getEstimatedSize() - startSize); - } - - public static void output(ArrayType outputType, MinMaxByNState state, BlockBuilder out) - { - TypedKeyValueHeap heap = state.getTypedKeyValueHeap(); - if (heap == null || heap.isEmpty()) { - out.appendNull(); - return; - } - - Type elementType = outputType.getElementType(); - - BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - BlockBuilder reversedBlockBuilder = elementType.createBlockBuilder(null, heap.getCapacity()); - long startSize = heap.getEstimatedSize(); - heap.popAll(reversedBlockBuilder); - state.addMemoryUsage(heap.getEstimatedSize() - startSize); - - for (int i = reversedBlockBuilder.getPositionCount() - 1; i >= 0; i--) { - elementType.appendTo(reversedBlockBuilder, i, arrayBlockBuilder); - } - out.closeEntry(); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByNAggregationFunction.java deleted file mode 100644 index 2ced99130ddf..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByNAggregationFunction.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.minmaxby; - -import io.trino.type.BlockTypeOperators; - -public class MaxByNAggregationFunction - extends AbstractMinMaxByNAggregationFunction -{ - private static final String NAME = "max_by"; - - public MaxByNAggregationFunction(BlockTypeOperators blockTypeOperators) - { - super(NAME, false, "Returns the values of the first argument associated with the maximum values of the second argument"); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNStateFactory.java deleted file mode 100644 index 97fde859cc87..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNStateFactory.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.minmaxby; - -import io.trino.array.ObjectBigArray; -import io.trino.operator.aggregation.TypedKeyValueHeap; -import io.trino.operator.aggregation.state.AbstractGroupedAccumulatorState; -import io.trino.spi.function.AccumulatorState; -import io.trino.spi.function.AccumulatorStateFactory; -import org.openjdk.jol.info.ClassLayout; - -public class MinMaxByNStateFactory - implements AccumulatorStateFactory -{ - @Override - public MinMaxByNState createSingleState() - { - return new SingleMinMaxByNState(); - } - - @Override - public MinMaxByNState createGroupedState() - { - return new GroupedMinMaxByNState(); - } - - public static class GroupedMinMaxByNState - extends AbstractGroupedAccumulatorState - implements MinMaxByNState - { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedMinMaxByNState.class).instanceSize(); - private final ObjectBigArray heaps = new ObjectBigArray<>(); - private long size; - - @Override - public void ensureCapacity(long size) - { - heaps.ensureCapacity(size); - } - - @Override - public long getEstimatedSize() - { - return INSTANCE_SIZE + heaps.sizeOf() + size; - } - - @Override - public TypedKeyValueHeap getTypedKeyValueHeap() - { - return heaps.get(getGroupId()); - } - - @Override - public void setTypedKeyValueHeap(TypedKeyValueHeap value) - { - TypedKeyValueHeap previous = getTypedKeyValueHeap(); - if (previous != null) { - size -= previous.getEstimatedSize(); - } - heaps.set(getGroupId(), value); - size += value.getEstimatedSize(); - } - - @Override - public void addMemoryUsage(long memory) - { - size += memory; - } - } - - public static class SingleMinMaxByNState - implements MinMaxByNState - { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMinMaxByNState.class).instanceSize(); - private TypedKeyValueHeap typedKeyValueHeap; - - public SingleMinMaxByNState() {} - - // for copying - private SingleMinMaxByNState(TypedKeyValueHeap typedKeyValueHeap) - { - this.typedKeyValueHeap = typedKeyValueHeap; - } - - @Override - public long getEstimatedSize() - { - long estimatedSize = INSTANCE_SIZE; - if (typedKeyValueHeap != null) { - estimatedSize += typedKeyValueHeap.getEstimatedSize(); - } - return estimatedSize; - } - - @Override - public TypedKeyValueHeap getTypedKeyValueHeap() - { - return typedKeyValueHeap; - } - - @Override - public void setTypedKeyValueHeap(TypedKeyValueHeap typedKeyValueHeap) - { - this.typedKeyValueHeap = typedKeyValueHeap; - } - - @Override - public void addMemoryUsage(long memory) - { - } - - @Override - public AccumulatorState copy() - { - TypedKeyValueHeap typedKeyValueHeapCopy = null; - if (typedKeyValueHeap != null) { - typedKeyValueHeapCopy = typedKeyValueHeap.copy(); - } - return new SingleMinMaxByNState(typedKeyValueHeapCopy); - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNStateSerializer.java deleted file mode 100644 index f0cafa5c95fd..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNStateSerializer.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.minmaxby; - -import io.trino.operator.aggregation.TypedKeyValueHeap; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.Type; - -import java.lang.invoke.MethodHandle; - -public class MinMaxByNStateSerializer - implements AccumulatorStateSerializer -{ - private final MethodHandle keyComparisonMethod; - private final Type keyType; - private final Type valueType; - private final Type serializedType; - - public MinMaxByNStateSerializer(MethodHandle keyComparisonMethod, Type keyType, Type valueType) - { - this.keyComparisonMethod = keyComparisonMethod; - this.keyType = keyType; - this.valueType = valueType; - this.serializedType = TypedKeyValueHeap.getSerializedType(keyType, valueType); - } - - @Override - public Type getSerializedType() - { - return serializedType; - } - - @Override - public void serialize(MinMaxByNState state, BlockBuilder out) - { - TypedKeyValueHeap heap = state.getTypedKeyValueHeap(); - if (heap == null) { - out.appendNull(); - return; - } - - heap.serialize(out); - } - - @Override - public void deserialize(Block block, int index, MinMaxByNState state) - { - Block currentBlock = (Block) serializedType.getObject(block, index); - state.setTypedKeyValueHeap(TypedKeyValueHeap.deserialize(currentBlock, keyType, valueType, keyComparisonMethod)); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java new file mode 100644 index 000000000000..691414610afa --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxbyn; + +import io.trino.operator.aggregation.NullablePosition; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +@AggregationFunction("max_by") +@Description("Returns the values of the first argument associated with the maximum values of the second argument") +public final class MaxByNAggregationFunction +{ + private MaxByNAggregationFunction() {} + + @InputFunction + @TypeParameter("K") + @TypeParameter("V") + public static void input( + @AggregationState({"K", "V"}) MaxByNState state, + @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockPosition @SqlType("K") Block keyBlock, + @SqlType("BIGINT") long n, + @BlockIndex int blockIndex) + { + state.initialize(n); + state.add(keyBlock, valueBlock, blockIndex); + } + + @CombineFunction + public static void combine( + @AggregationState({"K", "V"}) MaxByNState state, + @AggregationState({"K", "V"}) MaxByNState otherState) + { + state.merge(otherState); + } + + @OutputFunction("array(V)") + public static void output(@AggregationState({"K", "V"}) MaxByNState state, BlockBuilder out) + { + state.popAll(out); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNState.java similarity index 58% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByAggregationFunction.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNState.java index 460463c26364..bba6cd9ee1ea 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNState.java @@ -11,15 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation.minmaxby; +package io.trino.operator.aggregation.minmaxbyn; -public class MaxByAggregationFunction - extends AbstractMinMaxBy -{ - public static final MaxByAggregationFunction MAX_BY = new MaxByAggregationFunction(); +import io.trino.spi.function.AccumulatorStateMetadata; - public MaxByAggregationFunction() - { - super(false, "Returns the value of the first argument, associated with the maximum value of the second argument"); - } -} +@AccumulatorStateMetadata( + stateFactoryClass = MaxByNStateFactory.class, + stateSerializerClass = MaxByNStateSerializer.class, + typeParameters = {"K", "V"}, + serializedType = "ROW(BIGINT, ARRAY(K), ARRAY(V))") +public interface MaxByNState + extends MinMaxByNState +{} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java new file mode 100644 index 000000000000..57b9ee42eee5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxbyn; + +import io.trino.operator.aggregation.minmaxbyn.MinMaxByNStateFactory.GroupedMinMaxByNState; +import io.trino.operator.aggregation.minmaxbyn.MinMaxByNStateFactory.SingleMinMaxByNState; +import io.trino.spi.block.Block; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.Convention; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; + +import java.lang.invoke.MethodHandle; +import java.util.function.Function; +import java.util.function.LongFunction; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.util.Failures.checkCondition; +import static java.lang.Math.toIntExact; + +public class MaxByNStateFactory + implements AccumulatorStateFactory +{ + private static final long MAX_NUMBER_OF_VALUES = 10_000; + private final LongFunction heapFactory; + private final Function deserializer; + + public MaxByNStateFactory( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) + MethodHandle compare, + @TypeParameter("K") Type keyType, + @TypeParameter("V") Type valueType) + { + heapFactory = n -> { + checkCondition(n > 0, INVALID_FUNCTION_ARGUMENT, "third argument of max_by must be a positive integer"); + checkCondition( + n <= MAX_NUMBER_OF_VALUES, + INVALID_FUNCTION_ARGUMENT, + "third argument of max_by must be less than or equal to %s; found %s", + MAX_NUMBER_OF_VALUES, + n); + return new TypedKeyValueHeap(false, compare, keyType, valueType, toIntExact(n)); + }; + deserializer = rowBlock -> TypedKeyValueHeap.deserialize(false, compare, keyType, valueType, rowBlock); + } + + @Override + public MaxByNState createSingleState() + { + return new SingleMaxByNState(heapFactory, deserializer); + } + + @Override + public MaxByNState createGroupedState() + { + return new GroupedMaxByNState(heapFactory, deserializer); + } + + private static class GroupedMaxByNState + extends GroupedMinMaxByNState + implements MaxByNState + { + public GroupedMaxByNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + } + + private static class SingleMaxByNState + extends SingleMinMaxByNState + implements MaxByNState + { + public SingleMaxByNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + + public SingleMaxByNState(SingleMaxByNState state) + { + super(state); + } + + @Override + public AccumulatorState copy() + { + return new SingleMaxByNState(this); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateSerializer.java similarity index 61% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/MinNAggregationFunction.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateSerializer.java index e127be5cc0b2..86770af640dc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateSerializer.java @@ -11,17 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation; +package io.trino.operator.aggregation.minmaxbyn; -import io.trino.type.BlockTypeOperators; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; -public class MinNAggregationFunction - extends AbstractMinMaxNAggregationFunction +public class MaxByNStateSerializer + extends MinMaxByNStateSerializer { - private static final String NAME = "min"; - - public MinNAggregationFunction(BlockTypeOperators blockTypeOperators) + public MaxByNStateSerializer(@TypeParameter("ROW(BIGINT, ARRAY(K), ARRAY(V))") Type serializedType) { - super(NAME, true, "Returns the minimum values of the argument"); + super(serializedType); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java new file mode 100644 index 000000000000..e36d733095e7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxbyn; + +import io.trino.operator.aggregation.NullablePosition; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +@AggregationFunction("min_by") +@Description("Returns the values of the first argument associated with the minimum values of the second argument") +public final class MinByNAggregationFunction +{ + private MinByNAggregationFunction() {} + + @InputFunction + @TypeParameter("K") + @TypeParameter("V") + public static void input( + @AggregationState({"K", "V"}) MinByNState state, + @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockPosition @SqlType("K") Block keyBlock, + @SqlType("BIGINT") long n, + @BlockIndex int blockIndex) + { + state.initialize(n); + state.add(keyBlock, valueBlock, blockIndex); + } + + @CombineFunction + public static void combine( + @AggregationState({"K", "V"}) MinByNState state, + @AggregationState({"K", "V"}) MinByNState otherState) + { + state.merge(otherState); + } + + @OutputFunction("array(V)") + public static void output(@AggregationState({"K", "V"}) MinByNState state, BlockBuilder out) + { + state.popAll(out); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNState.java similarity index 58% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByAggregationFunction.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNState.java index f813ac3bab93..0bd96359d0e3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNState.java @@ -11,15 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation.minmaxby; +package io.trino.operator.aggregation.minmaxbyn; -public class MinByAggregationFunction - extends AbstractMinMaxBy -{ - public static final MinByAggregationFunction MIN_BY = new MinByAggregationFunction(); +import io.trino.spi.function.AccumulatorStateMetadata; - public MinByAggregationFunction() - { - super(true, "Returns the value of the first argument, associated with the minimum value of the second argument"); - } -} +@AccumulatorStateMetadata( + stateFactoryClass = MinByNStateFactory.class, + stateSerializerClass = MinByNStateSerializer.class, + typeParameters = {"K", "V"}, + serializedType = "ROW(BIGINT, ARRAY(K), ARRAY(V))") +public interface MinByNState + extends MinMaxByNState +{} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java new file mode 100644 index 000000000000..bf19fccb41a1 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxbyn; + +import io.trino.operator.aggregation.minmaxbyn.MinMaxByNStateFactory.GroupedMinMaxByNState; +import io.trino.operator.aggregation.minmaxbyn.MinMaxByNStateFactory.SingleMinMaxByNState; +import io.trino.spi.block.Block; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.Convention; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; + +import java.lang.invoke.MethodHandle; +import java.util.function.Function; +import java.util.function.LongFunction; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.util.Failures.checkCondition; +import static java.lang.Math.toIntExact; + +public class MinByNStateFactory + implements AccumulatorStateFactory +{ + private static final long MAX_NUMBER_OF_VALUES = 10_000; + private final LongFunction heapFactory; + private final Function deserializer; + + public MinByNStateFactory( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) + MethodHandle compare, + @TypeParameter("K") Type keyType, + @TypeParameter("V") Type valueType) + { + heapFactory = n -> { + checkCondition(n > 0, INVALID_FUNCTION_ARGUMENT, "third argument of min_by must be a positive integer"); + checkCondition( + n <= MAX_NUMBER_OF_VALUES, + INVALID_FUNCTION_ARGUMENT, + "third argument of min_by must be less than or equal to %s; found %s", + MAX_NUMBER_OF_VALUES, + n); + return new TypedKeyValueHeap(true, compare, keyType, valueType, toIntExact(n)); + }; + deserializer = rowBlock -> TypedKeyValueHeap.deserialize(true, compare, keyType, valueType, rowBlock); + } + + @Override + public MinByNState createSingleState() + { + return new SingleMinByNState(heapFactory, deserializer); + } + + @Override + public MinByNState createGroupedState() + { + return new GroupedMinByNState(heapFactory, deserializer); + } + + private static class GroupedMinByNState + extends GroupedMinMaxByNState + implements MinByNState + { + public GroupedMinByNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + } + + private static class SingleMinByNState + extends SingleMinMaxByNState + implements MinByNState + { + public SingleMinByNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + + public SingleMinByNState(SingleMinByNState state) + { + super(state); + } + + @Override + public AccumulatorState copy() + { + return new SingleMinByNState(this); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateSerializer.java similarity index 61% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/MaxNAggregationFunction.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateSerializer.java index cabd71ead03d..fbb058bde0a4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateSerializer.java @@ -11,17 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation; +package io.trino.operator.aggregation.minmaxbyn; -import io.trino.type.BlockTypeOperators; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; -public class MaxNAggregationFunction - extends AbstractMinMaxNAggregationFunction +public class MinByNStateSerializer + extends MinMaxByNStateSerializer { - private static final String NAME = "max"; - - public MaxNAggregationFunction(BlockTypeOperators blockTypeOperators) + public MinByNStateSerializer(@TypeParameter("ROW(BIGINT, ARRAY(K), ARRAY(V))") Type serializedType) { - super(NAME, false, "Returns the maximum values of the argument"); + super(serializedType); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java new file mode 100644 index 000000000000..d5bcd8e6c116 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxbyn; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorState; + +public interface MinMaxByNState + extends AccumulatorState +{ + /** + * Initialize the state if not already initialized. Only the first call is processed and + * all subsequent calls are ignored. + */ + void initialize(long n); + + /** + * Adds the value to this state. + */ + void add(Block keyBlock, Block valueBlock, int position); + + /** + * Merge with the specified state. + * The supplied state should not be used after this method is called, because + * the internal details of the state may be reused in this state. + */ + void merge(MinMaxByNState other); + + /** + * Writes all values to the supplied block builder as an array entry. + * After this method is called, the current state will be empty. + */ + void popAll(BlockBuilder out); + + /** + * Write this state to the specified block builder. + */ + void serialize(BlockBuilder out); + + /** + * Read the state to the specified block builder. + * + * @throws IllegalStateException if state is already initialized + */ + void deserialize(Block rowBlock); +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java new file mode 100644 index 000000000000..2157943a68ad --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java @@ -0,0 +1,269 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxbyn; + +import io.trino.array.ObjectBigArray; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.GroupedAccumulatorState; +import org.openjdk.jol.info.ClassLayout; + +import java.util.function.Function; +import java.util.function.LongFunction; + +import static com.google.common.base.Preconditions.checkState; + +public final class MinMaxByNStateFactory +{ + private abstract static class AbstractMinMaxByNState + implements MinMaxByNState + { + abstract TypedKeyValueHeap getTypedKeyValueHeap(); + } + + public abstract static class GroupedMinMaxByNState + extends AbstractMinMaxByNState + implements GroupedAccumulatorState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedMinMaxByNState.class).instanceSize(); + + private final LongFunction heapFactory; + private final Function deserializer; + + private final ObjectBigArray heaps = new ObjectBigArray<>(); + private long groupId; + private long size; + + public GroupedMinMaxByNState(LongFunction heapFactory, Function deserializer) + { + this.heapFactory = heapFactory; + this.deserializer = deserializer; + } + + @Override + public final void setGroupId(long groupId) + { + this.groupId = groupId; + } + + @Override + public final void ensureCapacity(long size) + { + heaps.ensureCapacity(size); + } + + @Override + public final long getEstimatedSize() + { + return INSTANCE_SIZE + heaps.sizeOf() + size; + } + + @Override + public final void initialize(long n) + { + if (getTypedKeyValueHeap() == null) { + TypedKeyValueHeap typedHeap = heapFactory.apply(n); + setTypedKeyValueHeap(typedHeap); + size += typedHeap.getEstimatedSize(); + } + } + + @Override + public final void add(Block keyBlock, Block valueBlock, int position) + { + TypedKeyValueHeap typedHeap = getTypedKeyValueHeap(); + + size -= typedHeap.getEstimatedSize(); + typedHeap.add(keyBlock, valueBlock, position); + size += typedHeap.getEstimatedSize(); + } + + @Override + public final void merge(MinMaxByNState other) + { + TypedKeyValueHeap otherTypedHeap = ((AbstractMinMaxByNState) other).getTypedKeyValueHeap(); + if (otherTypedHeap == null) { + return; + } + + TypedKeyValueHeap typedHeap = getTypedKeyValueHeap(); + if (typedHeap == null) { + setTypedKeyValueHeap(otherTypedHeap); + size += otherTypedHeap.getEstimatedSize(); + } + else { + size -= typedHeap.getEstimatedSize(); + typedHeap.addAll(otherTypedHeap); + size += typedHeap.getEstimatedSize(); + } + } + + @Override + public final void popAll(BlockBuilder out) + { + TypedKeyValueHeap typedHeap = getTypedKeyValueHeap(); + if (typedHeap == null || typedHeap.isEmpty()) { + out.appendNull(); + return; + } + + BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); + + size -= typedHeap.getEstimatedSize(); + typedHeap.popAllReverse(arrayBlockBuilder); + size += typedHeap.getEstimatedSize(); + + out.closeEntry(); + } + + @Override + public final void serialize(BlockBuilder out) + { + TypedKeyValueHeap typedHeap = getTypedKeyValueHeap(); + if (typedHeap == null) { + out.appendNull(); + } + else { + typedHeap.serialize(out); + } + } + + @Override + public final void deserialize(Block rowBlock) + { + checkState(getTypedKeyValueHeap() == null, "State already initialized"); + + TypedKeyValueHeap typedHeap = deserializer.apply(rowBlock); + setTypedKeyValueHeap(typedHeap); + size += typedHeap.getEstimatedSize(); + } + + @Override + final TypedKeyValueHeap getTypedKeyValueHeap() + { + return heaps.get(groupId); + } + + private void setTypedKeyValueHeap(TypedKeyValueHeap value) + { + heaps.set(groupId, value); + } + } + + public abstract static class SingleMinMaxByNState + extends AbstractMinMaxByNState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMinMaxByNState.class).instanceSize(); + + private final LongFunction heapFactory; + private final Function deserializer; + + private TypedKeyValueHeap typedHeap; + + public SingleMinMaxByNState(LongFunction heapFactory, Function deserializer) + { + this.heapFactory = heapFactory; + this.deserializer = deserializer; + } + + // for copying + protected SingleMinMaxByNState(SingleMinMaxByNState state) + { + this.heapFactory = state.heapFactory; + this.deserializer = state.deserializer; + + if (state.typedHeap != null) { + this.typedHeap = state.typedHeap.copy(); + } + else { + this.typedHeap = null; + } + } + + @Override + public abstract AccumulatorState copy(); + + @Override + public final long getEstimatedSize() + { + return INSTANCE_SIZE + (typedHeap == null ? 0 : typedHeap.getEstimatedSize()); + } + + @Override + public final void initialize(long n) + { + if (typedHeap == null) { + typedHeap = heapFactory.apply(n); + } + } + + @Override + public final void add(Block keyBlock, Block valueBlock, int position) + { + typedHeap.add(keyBlock, valueBlock, position); + } + + @Override + public final void merge(MinMaxByNState other) + { + TypedKeyValueHeap otherTypedHeap = ((AbstractMinMaxByNState) other).getTypedKeyValueHeap(); + if (otherTypedHeap == null) { + return; + } + if (typedHeap == null) { + typedHeap = otherTypedHeap; + } + else { + typedHeap.addAll(otherTypedHeap); + } + } + + @Override + public final void popAll(BlockBuilder out) + { + if (typedHeap == null || typedHeap.isEmpty()) { + out.appendNull(); + return; + } + + BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); + typedHeap.popAllReverse(arrayBlockBuilder); + out.closeEntry(); + } + + @Override + public final void serialize(BlockBuilder out) + { + if (typedHeap == null) { + out.appendNull(); + } + else { + typedHeap.serialize(out); + } + } + + @Override + public final void deserialize(Block rowBlock) + { + typedHeap = deserializer.apply(rowBlock); + } + + @Override + final TypedKeyValueHeap getTypedKeyValueHeap() + { + return typedHeap; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateSerializer.java new file mode 100644 index 000000000000..66219cbd6c6a --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateSerializer.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxbyn; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.Type; + +public abstract class MinMaxByNStateSerializer + implements AccumulatorStateSerializer +{ + private final Type serializedType; + + public MinMaxByNStateSerializer(Type serializedType) + { + this.serializedType = serializedType; + } + + @Override + public Type getSerializedType() + { + return serializedType; + } + + @Override + public void serialize(T state, BlockBuilder out) + { + state.serialize(out); + } + + @Override + public void deserialize(Block block, int index, T state) + { + Block rowBlock = (Block) serializedType.getObject(block, index); + state.deserialize(rowBlock); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedKeyValueHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java similarity index 75% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/TypedKeyValueHeap.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java index 41e3d01210a8..b3c5c697b152 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedKeyValueHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation; +package io.trino.operator.aggregation.minmaxbyn; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; @@ -28,6 +28,7 @@ import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.spi.type.BigintType.BIGINT; import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; public class TypedKeyValueHeap { @@ -36,7 +37,8 @@ public class TypedKeyValueHeap private static final int COMPACT_THRESHOLD_BYTES = 32768; private static final int COMPACT_THRESHOLD_RATIO = 3; // when 2/3 of elements in keyBlockBuilder is unreferenced, do compact - private final MethodHandle keyGreaterThan; + private final boolean min; + private final MethodHandle compare; private final Type keyType; private final Type valueType; private final int capacity; @@ -46,11 +48,12 @@ public class TypedKeyValueHeap private BlockBuilder keyBlockBuilder; private BlockBuilder valueBlockBuilder; - public TypedKeyValueHeap(MethodHandle keyGreaterThan, Type keyType, Type valueType, int capacity) + public TypedKeyValueHeap(boolean min, MethodHandle compare, Type keyType, Type valueType, int capacity) { - this.keyGreaterThan = keyGreaterThan; - this.keyType = keyType; - this.valueType = valueType; + this.min = min; + this.compare = requireNonNull(compare, "compare is null"); + this.keyType = requireNonNull(keyType, "keyType is null"); + this.valueType = requireNonNull(valueType, "valueType is null"); this.capacity = capacity; this.heapIndex = new int[capacity]; this.keyBlockBuilder = keyType.createBlockBuilder(null, capacity); @@ -58,11 +61,12 @@ public TypedKeyValueHeap(MethodHandle keyGreaterThan, Type keyType, Type valueTy } // for copying - private TypedKeyValueHeap(MethodHandle keyGreaterThan, Type keyType, Type valueType, int capacity, int positionCount, int[] heapIndex, BlockBuilder keyBlockBuilder, BlockBuilder valueBlockBuilder) + private TypedKeyValueHeap(boolean min, MethodHandle compare, Type keyType, Type valueType, int capacity, int positionCount, int[] heapIndex, BlockBuilder keyBlockBuilder, BlockBuilder valueBlockBuilder) { - this.keyGreaterThan = keyGreaterThan; - this.keyType = keyType; - this.valueType = valueType; + this.min = min; + this.compare = requireNonNull(compare, "compare is null"); + this.keyType = requireNonNull(keyType, "keyType is null"); + this.valueType = requireNonNull(valueType, "valueType is null"); this.capacity = capacity; this.positionCount = positionCount; this.heapIndex = heapIndex; @@ -110,14 +114,46 @@ public void serialize(BlockBuilder out) out.closeEntry(); } - public static TypedKeyValueHeap deserialize(Block block, Type keyType, Type valueType, MethodHandle keyComparisonOperator) + public static TypedKeyValueHeap deserialize(boolean min, MethodHandle compare, Type keyType, Type valueType, Block rowBlock) { - int capacity = toIntExact(BIGINT.getLong(block, 0)); - Block keysBlock = new ArrayType(keyType).getObject(block, 1); - Block valuesBlock = new ArrayType(valueType).getObject(block, 2); - TypedKeyValueHeap heap = new TypedKeyValueHeap(keyComparisonOperator, keyType, valueType, capacity); - heap.addAll(keysBlock, valuesBlock); - return heap; + int capacity = toIntExact(BIGINT.getLong(rowBlock, 0)); + int[] heapIndex = new int[capacity]; + + BlockBuilder keyBlockBuilder = keyType.createBlockBuilder(null, capacity); + Block keyBlock = new ArrayType(keyType).getObject(rowBlock, 1); + for (int position = 0; position < keyBlock.getPositionCount(); position++) { + heapIndex[position] = position; + keyType.appendTo(keyBlock, position, keyBlockBuilder); + } + + BlockBuilder valueBlockBuilder = valueType.createBlockBuilder(null, capacity); + Block valueBlock = new ArrayType(valueType).getObject(rowBlock, 2); + for (int position = 0; position < valueBlock.getPositionCount(); position++) { + heapIndex[position] = position; + if (valueBlock.isNull(position)) { + valueBlockBuilder.appendNull(); + } + else { + valueType.appendTo(valueBlock, position, valueBlockBuilder); + } + } + + return new TypedKeyValueHeap(min, compare, keyType, valueType, capacity, keyBlock.getPositionCount(), heapIndex, keyBlockBuilder, valueBlockBuilder); + } + + public void popAllReverse(BlockBuilder resultBlockBuilder) + { + int[] indexes = new int[positionCount]; + while (positionCount > 0) { + indexes[positionCount - 1] = heapIndex[0]; + positionCount--; + heapIndex[0] = heapIndex[positionCount]; + siftDown(); + } + + for (int index : indexes) { + valueType.appendTo(valueBlockBuilder, index, resultBlockBuilder); + } } public void popAll(BlockBuilder resultBlockBuilder) @@ -237,10 +273,8 @@ private void compactIfNecessary() private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) { try { - // Swap the argument order to get a less than operator, and negate the result to get greater than or equals. - // Note: the keyGreaterThan operator is based comparison, and specifically is not a pure greater than operator. - // This means negation of the result is safe for unordered values. - return !((boolean) keyGreaterThan.invokeExact(rightBlock, rightPosition, leftBlock, leftPosition)); + long result = (long) compare.invokeExact(leftBlock, leftPosition, rightBlock, rightPosition); + return min ? result <= 0 : result >= 0; } catch (Throwable throwable) { Throwables.throwIfUnchecked(throwable); @@ -259,7 +293,8 @@ public TypedKeyValueHeap copy() valueBlockBuilderCopy = (BlockBuilder) valueBlockBuilder.copyRegion(0, valueBlockBuilder.getPositionCount()); } return new TypedKeyValueHeap( - keyGreaterThan, + min, + compare, keyType, valueType, capacity, diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java new file mode 100644 index 000000000000..ef6b86d6cd89 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +@AggregationFunction("max") +@Description("Returns the maximum values of the argument") +public final class MaxNAggregationFunction +{ + private MaxNAggregationFunction() {} + + @InputFunction + @TypeParameter("E") + public static void input( + @AggregationState("E") MaxNState state, + @BlockPosition @SqlType("E") Block block, + @SqlType("BIGINT") long n, + @BlockIndex int blockIndex) + { + state.initialize(n); + state.add(block, blockIndex); + } + + @CombineFunction + public static void combine( + @AggregationState("E") MaxNState state, + @AggregationState("E") MaxNState otherState) + { + state.merge(otherState); + } + + @OutputFunction("array(E)") + public static void output(@AggregationState("E") MaxNState state, BlockBuilder out) + { + state.popAll(out); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNState.java new file mode 100644 index 000000000000..9617386665d6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNState.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.spi.function.AccumulatorStateMetadata; + +@AccumulatorStateMetadata( + stateFactoryClass = MaxNStateFactory.class, + stateSerializerClass = MaxNStateSerializer.class, + typeParameters = "T", + serializedType = "ROW(BIGINT, ARRAY(T))") +public interface MaxNState + extends MinMaxNState {} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java new file mode 100644 index 000000000000..448b78b0cccf --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.operator.aggregation.minmaxn.MinMaxNStateFactory.SingleMinMaxNState; +import io.trino.spi.block.Block; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.Convention; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; + +import java.lang.invoke.MethodHandle; +import java.util.function.Function; +import java.util.function.LongFunction; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.util.Failures.checkCondition; +import static java.lang.Math.toIntExact; + +public class MaxNStateFactory + implements AccumulatorStateFactory +{ + private static final long MAX_NUMBER_OF_VALUES = 10_000; + private final LongFunction heapFactory; + private final Function deserializer; + + public MaxNStateFactory( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"T", "T"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) + MethodHandle compare, + @TypeParameter("T") Type elementType) + { + heapFactory = n -> { + checkCondition(n > 0, INVALID_FUNCTION_ARGUMENT, "second argument of max_n must be positive"); + checkCondition( + n <= MAX_NUMBER_OF_VALUES, + INVALID_FUNCTION_ARGUMENT, + "second argument of max_n must be less than or equal to %s; found %s", + MAX_NUMBER_OF_VALUES, + n); + return new TypedHeap(false, compare, elementType, toIntExact(n)); + }; + deserializer = rowBlock -> TypedHeap.deserialize(false, compare, elementType, rowBlock); + } + + @Override + public MaxNState createSingleState() + { + return new SingleMaxNState(heapFactory, deserializer); + } + + @Override + public MaxNState createGroupedState() + { + return new GroupedMaxNState(heapFactory, deserializer); + } + + private static class GroupedMaxNState + extends MinMaxNStateFactory.GroupedMinMaxNState + implements MaxNState + { + public GroupedMaxNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + } + + private static class SingleMaxNState + extends SingleMinMaxNState + implements MaxNState + { + public SingleMaxNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + + public SingleMaxNState(SingleMaxNState state) + { + super(state); + } + + @Override + public AccumulatorState copy() + { + return new SingleMaxNState(this); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateSerializer.java similarity index 62% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNState.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateSerializer.java index 6281d39d8688..c92073a13ddd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateSerializer.java @@ -11,17 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation.state; +package io.trino.operator.aggregation.minmaxn; -import io.trino.operator.aggregation.TypedHeap; -import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; -public interface MinMaxNState - extends AccumulatorState +public class MaxNStateSerializer + extends MinMaxNStateSerializer { - TypedHeap getTypedHeap(); - - void setTypedHeap(TypedHeap value); - - void addMemoryUsage(long memory); + public MaxNStateSerializer(@TypeParameter("ROW(BIGINT, ARRAY(T))") Type serializedType) + { + super(serializedType); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java new file mode 100644 index 000000000000..2860f00a53f0 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorState; + +public interface MinMaxNState + extends AccumulatorState +{ + /** + * Initialize the state if not already initialized. Only the first call is processed and + * all subsequent calls are ignored. + */ + void initialize(long n); + + /** + * Adds the value to this state. + */ + void add(Block block, int position); + + /** + * Merge with the specified state. + * The supplied state should not be used after this method is called, because + * the internal details of the state may be reused in this state. + */ + void merge(MinMaxNState other); + + /** + * Writes all values to the supplied block builder as an array entry. + * After this method is called, the current state will be empty. + */ + void popAll(BlockBuilder out); + + /** + * Write this state to the specified block builder. + */ + void serialize(BlockBuilder out); + + /** + * Read the state to the specified block builder. + * + * @throws IllegalStateException if state is already initialized + */ + void deserialize(Block rowBlock); +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java new file mode 100644 index 000000000000..c69503a107cc --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java @@ -0,0 +1,271 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.array.ObjectBigArray; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.GroupedAccumulatorState; +import org.openjdk.jol.info.ClassLayout; + +import java.util.function.Function; +import java.util.function.LongFunction; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public final class MinMaxNStateFactory +{ + private MinMaxNStateFactory() {} + + private abstract static class AbstractMinMaxNState + implements MinMaxNState + { + abstract TypedHeap getTypedHeap(); + } + + public abstract static class GroupedMinMaxNState + extends AbstractMinMaxNState + implements GroupedAccumulatorState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedMinMaxNState.class).instanceSize(); + + private final LongFunction heapFactory; + private final Function deserializer; + + private final ObjectBigArray heaps = new ObjectBigArray<>(); + private long groupId; + private long size; + + public GroupedMinMaxNState(LongFunction heapFactory, Function deserializer) + { + this.heapFactory = heapFactory; + this.deserializer = deserializer; + } + + @Override + public final void setGroupId(long groupId) + { + this.groupId = groupId; + } + + @Override + public final void ensureCapacity(long size) + { + heaps.ensureCapacity(size); + } + + @Override + public final long getEstimatedSize() + { + return INSTANCE_SIZE + heaps.sizeOf() + size; + } + + @Override + public final void initialize(long n) + { + if (getTypedHeap() == null) { + TypedHeap typedHeap = heapFactory.apply(n); + setTypedHeap(typedHeap); + size += typedHeap.getEstimatedSize(); + } + } + + @Override + public final void add(Block block, int position) + { + TypedHeap typedHeap = getTypedHeap(); + + size -= typedHeap.getEstimatedSize(); + typedHeap.add(block, position); + size += typedHeap.getEstimatedSize(); + } + + @Override + public final void merge(MinMaxNState other) + { + TypedHeap otherTypedHeap = ((AbstractMinMaxNState) other).getTypedHeap(); + if (otherTypedHeap == null) { + return; + } + + TypedHeap typedHeap = getTypedHeap(); + if (typedHeap == null) { + setTypedHeap(otherTypedHeap); + size += otherTypedHeap.getEstimatedSize(); + } + else { + size -= typedHeap.getEstimatedSize(); + typedHeap.addAll(otherTypedHeap); + size += typedHeap.getEstimatedSize(); + } + } + + @Override + public final void popAll(BlockBuilder out) + { + TypedHeap typedHeap = getTypedHeap(); + if (typedHeap == null || typedHeap.isEmpty()) { + out.appendNull(); + return; + } + + BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); + + size -= typedHeap.getEstimatedSize(); + typedHeap.popAllReverse(arrayBlockBuilder); + size += typedHeap.getEstimatedSize(); + + out.closeEntry(); + } + + @Override + public final void serialize(BlockBuilder out) + { + TypedHeap typedHeap = getTypedHeap(); + if (typedHeap == null) { + out.appendNull(); + } + else { + typedHeap.serialize(out); + } + } + + @Override + public final void deserialize(Block rowBlock) + { + checkState(getTypedHeap() == null, "State already initialized"); + + TypedHeap typedHeap = deserializer.apply(rowBlock); + setTypedHeap(typedHeap); + size += typedHeap.getEstimatedSize(); + } + + @Override + final TypedHeap getTypedHeap() + { + return heaps.get(groupId); + } + + private void setTypedHeap(TypedHeap value) + { + heaps.set(groupId, value); + } + } + + public abstract static class SingleMinMaxNState + extends AbstractMinMaxNState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMinMaxNState.class).instanceSize(); + + private final LongFunction heapFactory; + private final Function deserializer; + + private TypedHeap typedHeap; + + public SingleMinMaxNState(LongFunction heapFactory, Function deserializer) + { + this.heapFactory = requireNonNull(heapFactory, "heapFactory is null"); + this.deserializer = requireNonNull(deserializer, "deserializer is null"); + } + + protected SingleMinMaxNState(SingleMinMaxNState state) + { + this.heapFactory = state.heapFactory; + this.deserializer = state.deserializer; + + if (state.typedHeap != null) { + this.typedHeap = state.typedHeap.copy(); + } + else { + this.typedHeap = null; + } + } + + @Override + public abstract AccumulatorState copy(); + + @Override + public final long getEstimatedSize() + { + return INSTANCE_SIZE + (typedHeap == null ? 0 : typedHeap.getEstimatedSize()); + } + + @Override + public final void initialize(long n) + { + if (typedHeap == null) { + typedHeap = heapFactory.apply(n); + } + } + + @Override + public final void add(Block block, int position) + { + typedHeap.add(block, position); + } + + @Override + public final void merge(MinMaxNState other) + { + TypedHeap otherTypedHeap = ((AbstractMinMaxNState) other).getTypedHeap(); + if (otherTypedHeap == null) { + return; + } + if (typedHeap == null) { + typedHeap = otherTypedHeap; + } + else { + typedHeap.addAll(otherTypedHeap); + } + } + + @Override + public final void popAll(BlockBuilder out) + { + if (typedHeap == null || typedHeap.isEmpty()) { + out.appendNull(); + return; + } + + BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); + typedHeap.popAllReverse(arrayBlockBuilder); + out.closeEntry(); + } + + @Override + public final void serialize(BlockBuilder out) + { + if (typedHeap == null) { + out.appendNull(); + } + else { + typedHeap.serialize(out); + } + } + + @Override + public final void deserialize(Block rowBlock) + { + typedHeap = deserializer.apply(rowBlock); + } + + @Override + final TypedHeap getTypedHeap() + { + return typedHeap; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateSerializer.java new file mode 100644 index 000000000000..fc87847dff7c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateSerializer.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.Type; + +import static java.util.Objects.requireNonNull; + +public abstract class MinMaxNStateSerializer + implements AccumulatorStateSerializer +{ + private final Type serializedType; + + public MinMaxNStateSerializer(Type serializedType) + { + this.serializedType = requireNonNull(serializedType, "serializedType is null"); + } + + @Override + public Type getSerializedType() + { + return serializedType; + } + + @Override + public void serialize(T state, BlockBuilder out) + { + state.serialize(out); + } + + @Override + public void deserialize(Block block, int index, T state) + { + Block rowBlock = (Block) serializedType.getObject(block, index); + state.deserialize(rowBlock); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java new file mode 100644 index 000000000000..aae0eefced43 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +@AggregationFunction("min") +@Description("Returns the minimum values of the argument") +public final class MinNAggregationFunction +{ + private MinNAggregationFunction() {} + + @InputFunction + @TypeParameter("E") + public static void input( + @AggregationState("E") MinNState state, + @BlockPosition @SqlType("E") Block block, + @SqlType("BIGINT") long n, + @BlockIndex int blockIndex) + { + state.initialize(n); + state.add(block, blockIndex); + } + + @CombineFunction + public static void combine( + @AggregationState("E") MinNState state, + @AggregationState("E") MinNState otherState) + { + state.merge(otherState); + } + + @OutputFunction("array(E)") + public static void output(@AggregationState("E") MinNState state, BlockBuilder out) + { + state.popAll(out); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNState.java new file mode 100644 index 000000000000..0834279eb3a7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNState.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.spi.function.AccumulatorStateMetadata; + +@AccumulatorStateMetadata( + stateFactoryClass = MinNStateFactory.class, + stateSerializerClass = MinNStateSerializer.class, + typeParameters = "T", + serializedType = "ROW(BIGINT, ARRAY(T))") +public interface MinNState + extends MinMaxNState +{} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java new file mode 100644 index 000000000000..fa432a9ce66b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.minmaxn; + +import io.trino.operator.aggregation.minmaxn.MinMaxNStateFactory.SingleMinMaxNState; +import io.trino.spi.block.Block; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.Convention; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; + +import java.lang.invoke.MethodHandle; +import java.util.function.Function; +import java.util.function.LongFunction; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.util.Failures.checkCondition; +import static java.lang.Math.toIntExact; + +public class MinNStateFactory + implements AccumulatorStateFactory +{ + private static final long MAX_NUMBER_OF_VALUES = 10_000; + private final LongFunction heapFactory; + private final Function deserializer; + + public MinNStateFactory( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"T", "T"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) + MethodHandle compare, + @TypeParameter("T") Type elementType) + { + heapFactory = n -> { + checkCondition(n > 0, INVALID_FUNCTION_ARGUMENT, "second argument of min_n must be positive"); + checkCondition( + n <= MAX_NUMBER_OF_VALUES, + INVALID_FUNCTION_ARGUMENT, + "second argument of min_n must be less than or equal to %s; found %s", + MAX_NUMBER_OF_VALUES, + n); + return new TypedHeap(true, compare, elementType, toIntExact(n)); + }; + deserializer = rowBlock -> TypedHeap.deserialize(true, compare, elementType, rowBlock); + } + + @Override + public MinNState createSingleState() + { + return new SingleMinNState(heapFactory, deserializer); + } + + @Override + public MinNState createGroupedState() + { + return new GroupedMinNState(heapFactory, deserializer); + } + + private static class GroupedMinNState + extends MinMaxNStateFactory.GroupedMinMaxNState + implements MinNState + { + public GroupedMinNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + } + + private static class SingleMinNState + extends SingleMinMaxNState + implements MinNState + { + public SingleMinNState(LongFunction heapFactory, Function deserializer) + { + super(heapFactory, deserializer); + } + + public SingleMinNState(SingleMinNState state) + { + super(state); + } + + @Override + public AccumulatorState copy() + { + return new SingleMinNState(this); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateSerializer.java similarity index 61% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNState.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateSerializer.java index 7dbfa5a7db67..3762a859498c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinMaxByNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateSerializer.java @@ -11,17 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation.minmaxby; +package io.trino.operator.aggregation.minmaxn; -import io.trino.operator.aggregation.TypedKeyValueHeap; -import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; -public interface MinMaxByNState - extends AccumulatorState +public class MinNStateSerializer + extends MinMaxNStateSerializer { - TypedKeyValueHeap getTypedKeyValueHeap(); - - void setTypedKeyValueHeap(TypedKeyValueHeap value); - - void addMemoryUsage(long memory); + public MinNStateSerializer(@TypeParameter("ROW(BIGINT, ARRAY(T))") Type serializedType) + { + super(serializedType); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java similarity index 63% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/TypedHeap.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java index cd5f6d2d2439..62e073d04f23 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java @@ -11,11 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.aggregation; +package io.trino.operator.aggregation.minmaxn; import com.google.common.base.Throwables; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import org.openjdk.jol.info.ClassLayout; @@ -23,6 +24,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.sizeOf; +import static io.trino.spi.type.BigintType.BIGINT; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; public class TypedHeap { @@ -31,28 +35,31 @@ public class TypedHeap private static final int COMPACT_THRESHOLD_BYTES = 32768; private static final int COMPACT_THRESHOLD_RATIO = 3; // when 2/3 of elements in heapBlockBuilder is unreferenced, do compact - private final MethodHandle greaterThanMethod; - private final Type type; + private final boolean min; + private final MethodHandle compare; + private final Type elementType; private final int capacity; private int positionCount; private final int[] heapIndex; private BlockBuilder heapBlockBuilder; - public TypedHeap(MethodHandle greaterThanMethod, Type type, int capacity) + public TypedHeap(boolean min, MethodHandle compare, Type elementType, int capacity) { - this.greaterThanMethod = greaterThanMethod; - this.type = type; + this.min = min; + this.compare = requireNonNull(compare, "compare is null"); + this.elementType = requireNonNull(elementType, "elementType is null"); this.capacity = capacity; this.heapIndex = new int[capacity]; - this.heapBlockBuilder = type.createBlockBuilder(null, capacity); + this.heapBlockBuilder = elementType.createBlockBuilder(null, capacity); } // for copying - private TypedHeap(MethodHandle greaterThanMethod, Type type, int capacity, int positionCount, int[] heapIndex, BlockBuilder heapBlockBuilder) + private TypedHeap(boolean min, MethodHandle compare, Type elementType, int capacity, int positionCount, int[] heapIndex, BlockBuilder heapBlockBuilder) { - this.greaterThanMethod = greaterThanMethod; - this.type = type; + this.min = min; + this.compare = requireNonNull(compare, "compare is null"); + this.elementType = requireNonNull(elementType, "elementType is null"); this.capacity = capacity; this.positionCount = positionCount; this.heapIndex = heapIndex; @@ -66,7 +73,7 @@ public int getCapacity() public long getEstimatedSize() { - return INSTANCE_SIZE + heapBlockBuilder.getRetainedSizeInBytes() + sizeOf(heapIndex); + return INSTANCE_SIZE + (heapBlockBuilder == null ? 0 : heapBlockBuilder.getRetainedSizeInBytes()) + sizeOf(heapIndex); } public boolean isEmpty() @@ -74,10 +81,48 @@ public boolean isEmpty() return positionCount == 0; } - public void writeAll(BlockBuilder resultBlockBuilder) + public void serialize(BlockBuilder out) { + BlockBuilder blockBuilder = out.beginBlockEntry(); + BIGINT.writeLong(blockBuilder, capacity); + + BlockBuilder elements = blockBuilder.beginBlockEntry(); for (int i = 0; i < positionCount; i++) { - type.appendTo(heapBlockBuilder, heapIndex[i], resultBlockBuilder); + elementType.appendTo(heapBlockBuilder, heapIndex[i], elements); + } + blockBuilder.closeEntry(); + + out.closeEntry(); + } + + public static TypedHeap deserialize(boolean min, MethodHandle compare, Type elementType, Block rowBlock) + { + int capacity = toIntExact(BIGINT.getLong(rowBlock, 0)); + int[] heapIndex = new int[capacity]; + + BlockBuilder heapBlockBuilder = elementType.createBlockBuilder(null, capacity); + + Block heapBlock = new ArrayType(elementType).getObject(rowBlock, 1); + for (int position = 0; position < heapBlock.getPositionCount(); position++) { + heapIndex[position] = position; + elementType.appendTo(heapBlock, position, heapBlockBuilder); + } + + return new TypedHeap(min, compare, elementType, capacity, heapBlock.getPositionCount(), heapIndex, heapBlockBuilder); + } + + public void popAllReverse(BlockBuilder resultBlockBuilder) + { + int[] indexes = new int[positionCount]; + while (positionCount > 0) { + indexes[positionCount - 1] = heapIndex[0]; + positionCount--; + heapIndex[0] = heapIndex[positionCount]; + siftDown(); + } + + for (int index : indexes) { + elementType.appendTo(heapBlockBuilder, index, resultBlockBuilder); } } @@ -90,7 +135,7 @@ public void popAll(BlockBuilder resultBlockBuilder) public void pop(BlockBuilder resultBlockBuilder) { - type.appendTo(heapBlockBuilder, heapIndex[0], resultBlockBuilder); + elementType.appendTo(heapBlockBuilder, heapIndex[0], resultBlockBuilder); remove(); } @@ -105,17 +150,17 @@ public void add(Block block, int position) { checkArgument(!block.isNull(position)); if (positionCount == capacity) { - if (keyGreaterThanOrEqual(this.heapBlockBuilder, heapIndex[0], block, position)) { + if (keyGreaterThanOrEqual(heapBlockBuilder, heapIndex[0], block, position)) { return; // and new element is not larger than heap top: do not add } heapIndex[0] = heapBlockBuilder.getPositionCount(); - type.appendTo(block, position, heapBlockBuilder); + elementType.appendTo(block, position, heapBlockBuilder); siftDown(); } else { heapIndex[positionCount] = heapBlockBuilder.getPositionCount(); positionCount++; - type.appendTo(block, position, heapBlockBuilder); + elementType.appendTo(block, position, heapBlockBuilder); siftUp(); } compactIfNecessary(); @@ -178,15 +223,15 @@ private void siftUp() private void compactIfNecessary() { - // Byte size check is needed. Otherwise, if size * 3 is small, BlockBuilder can be reallocate too often. + // Byte size check is needed. Otherwise, if size * 3 is small, BlockBuilder can be reallocated too often. // Position count is needed. Otherwise, for large elements, heap will be compacted every time. // Size instead of retained size is needed because default allocation size can be huge for some block builders. And the first check will become useless in such case. if (heapBlockBuilder.getSizeInBytes() < COMPACT_THRESHOLD_BYTES || heapBlockBuilder.getPositionCount() / positionCount < COMPACT_THRESHOLD_RATIO) { return; } - BlockBuilder newHeapBlockBuilder = type.createBlockBuilder(null, heapBlockBuilder.getPositionCount()); + BlockBuilder newHeapBlockBuilder = elementType.createBlockBuilder(null, heapBlockBuilder.getPositionCount()); for (int i = 0; i < positionCount; i++) { - type.appendTo(heapBlockBuilder, heapIndex[i], newHeapBlockBuilder); + elementType.appendTo(heapBlockBuilder, heapIndex[i], newHeapBlockBuilder); heapIndex[i] = i; } heapBlockBuilder = newHeapBlockBuilder; @@ -195,8 +240,8 @@ private void compactIfNecessary() private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) { try { - // this is a greater than operator, so we swap the object order and not the result - return !((boolean) greaterThanMethod.invokeExact(rightBlock, rightPosition, leftBlock, leftPosition)); + long result = (long) compare.invokeExact(leftBlock, leftPosition, rightBlock, rightPosition); + return min ? result < 0 : result > 0; } catch (Throwable throwable) { Throwables.throwIfUnchecked(throwable); @@ -210,6 +255,6 @@ public TypedHeap copy() if (heapBlockBuilder != null) { heapBlockBuilderCopy = (BlockBuilder) heapBlockBuilder.copyRegion(0, heapBlockBuilder.getPositionCount()); } - return new TypedHeap(greaterThanMethod, type, capacity, positionCount, heapIndex.clone(), heapBlockBuilderCopy); + return new TypedHeap(min, compare, elementType, capacity, positionCount, heapIndex.clone(), heapBlockBuilderCopy); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java index 7cb2d84ba615..986d03094142 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java @@ -13,121 +13,79 @@ */ package io.trino.operator.aggregation.multimapagg; -import com.google.common.collect.ImmutableList; import io.trino.array.ObjectBigArray; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; +import io.trino.operator.aggregation.NullablePosition; import io.trino.operator.aggregation.TypedSet; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.type.BlockTypeOperators; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.util.Optional; - import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; -import static io.trino.spi.type.TypeSignature.arrayType; -import static io.trino.spi.type.TypeSignature.mapType; -import static io.trino.spi.type.TypeSignature.rowType; -import static io.trino.spi.type.TypeSignatureParameter.anonymousField; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.type.TypeUtils.expectedValueSize; -import static io.trino.util.Reflection.methodHandle; -import static java.util.Objects.requireNonNull; -public class MultimapAggregationFunction - extends SqlAggregationFunction +@AggregationFunction(value = "multimap_agg", isOrderSensitive = true) +@Description("Aggregates all the rows (key/value pairs) into a single multimap") +public final class MultimapAggregationFunction { - public static final String NAME = "multimap_agg"; - private static final MethodHandle OUTPUT_FUNCTION = methodHandle( - MultimapAggregationFunction.class, - "output", - Type.class, - BlockPositionIsDistinctFrom.class, - BlockPositionHashCode.class, - Type.class, - MultimapAggregationState.class, - BlockBuilder.class); - private static final MethodHandle COMBINE_FUNCTION = methodHandle( - MultimapAggregationFunction.class, - "combine", - MultimapAggregationState.class, - MultimapAggregationState.class); - private static final MethodHandle INPUT_FUNCTION = methodHandle( - MultimapAggregationFunction.class, - "input", - MultimapAggregationState.class, - Block.class, - Block.class, - int.class); private static final int EXPECTED_ENTRY_SIZE = 100; - private final BlockTypeOperators blockTypeOperators; - public MultimapAggregationFunction(BlockTypeOperators blockTypeOperators) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name(NAME) - .comparableTypeParameter("K") - .typeVariable("V") - .returnType(mapType(new TypeSignature("K"), arrayType(new TypeSignature("V")))) - .argumentType(new TypeSignature("K")) - .argumentType(new TypeSignature("V")) - .build()) - .argumentNullability(false, true) - .description("Aggregates all the rows (key/value pairs) into a single multimap") - .build(), - AggregationFunctionMetadata.builder() - .orderSensitive() - .intermediateType(arrayType(rowType(anonymousField(new TypeSignature("V")), anonymousField(new TypeSignature("K"))))) - .build()); - this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature) - { - Type keyType = boundSignature.getArgumentType(0); - BlockPositionIsDistinctFrom keyDistinctOperator = blockTypeOperators.getDistinctFromOperator(keyType); - BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType); - - Type valueType = boundSignature.getArgumentType(1); - - MultimapAggregationStateSerializer stateSerializer = new MultimapAggregationStateSerializer(keyType, valueType); - - return new AggregationMetadata( - INPUT_FUNCTION, - Optional.empty(), - Optional.of(COMBINE_FUNCTION), - MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyDistinctOperator, keyHashCode, valueType), - ImmutableList.of(new AccumulatorStateDescriptor<>( - MultimapAggregationState.class, - stateSerializer, - new MultimapAggregationStateFactory(keyType, valueType)))); - } + private MultimapAggregationFunction() {} - public static void input(MultimapAggregationState state, Block key, Block value, int position) + @InputFunction + @TypeParameter("K") + @TypeParameter("V") + public static void input( + @AggregationState({"K", "V"}) MultimapAggregationState state, + @BlockPosition @SqlType("K") Block key, + @NullablePosition @BlockPosition @SqlType("V") Block value, + @BlockIndex int position) { state.add(key, value, position); } - public static void combine(MultimapAggregationState state, MultimapAggregationState otherState) + @CombineFunction + public static void combine( + @AggregationState({"K", "V"}) MultimapAggregationState state, + @AggregationState({"K", "V"}) MultimapAggregationState otherState) { state.merge(otherState); } - public static void output(Type keyType, BlockPositionIsDistinctFrom keyDistinctOperator, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out) + @OutputFunction("map(K, array(V))") + public static void output( + @TypeParameter("K") Type keyType, + @OperatorDependency( + operator = OperatorType.IS_DISTINCT_FROM, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) + BlockPositionIsDistinctFrom keyDistinctFrom, + @OperatorDependency( + operator = OperatorType.HASH_CODE, + argumentTypes = "K", + convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) + BlockPositionHashCode keyHashCode, + @TypeParameter("V") Type valueType, + @AggregationState({"K", "V"}) MultimapAggregationState state, + BlockBuilder out) { if (state.isEmpty()) { out.appendNull(); @@ -137,7 +95,7 @@ public static void output(Type keyType, BlockPositionIsDistinctFrom keyDistinctO ObjectBigArray valueArrayBlockBuilders = new ObjectBigArray<>(); valueArrayBlockBuilders.ensureCapacity(state.getEntryCount()); BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, state.getEntryCount(), expectedValueSize(keyType, 100)); - TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctOperator, keyHashCode, state.getEntryCount(), NAME); + TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctFrom, keyHashCode, state.getEntryCount(), "multimap_agg"); state.forEach((key, value, keyValueIndex) -> { // Merge values of the same key into an array diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java index 5aeb0a431b40..97aeb9307962 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java @@ -17,7 +17,11 @@ import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; -@AccumulatorStateMetadata(stateFactoryClass = MultimapAggregationStateFactory.class, stateSerializerClass = MultimapAggregationStateSerializer.class) +@AccumulatorStateMetadata( + stateFactoryClass = MultimapAggregationStateFactory.class, + stateSerializerClass = MultimapAggregationStateSerializer.class, + typeParameters = {"K", "V"}, + serializedType = "array(row(V, K))") public interface MultimapAggregationState extends AccumulatorState { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java index 218190ae25a3..8be1645f5ae3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java @@ -14,6 +14,7 @@ package io.trino.operator.aggregation.multimapagg; import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; import static java.util.Objects.requireNonNull; @@ -24,7 +25,7 @@ public class MultimapAggregationStateFactory private final Type keyType; private final Type valueType; - public MultimapAggregationStateFactory(Type keyType, Type valueType) + public MultimapAggregationStateFactory(@TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType) { this.keyType = requireNonNull(keyType); this.valueType = requireNonNull(valueType); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java index 0c7916656090..e3acd8549b8b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java @@ -18,6 +18,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.ColumnarRow; import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; @@ -34,7 +35,7 @@ public class MultimapAggregationStateSerializer private final Type valueType; private final ArrayType arrayType; - public MultimapAggregationStateSerializer(Type keyType, Type valueType) + public MultimapAggregationStateSerializer(@TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType) { this.keyType = requireNonNull(keyType); this.valueType = requireNonNull(valueType); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/InOutStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/InOutStateSerializer.java new file mode 100644 index 000000000000..43acbcae860e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/InOutStateSerializer.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.state; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.InOut; +import io.trino.spi.type.Type; + +import static java.util.Objects.requireNonNull; + +public final class InOutStateSerializer + implements AccumulatorStateSerializer +{ + private final Type type; + + public InOutStateSerializer(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + @Override + public Type getSerializedType() + { + return type; + } + + @Override + public void serialize(InOut state, BlockBuilder out) + { + state.get(out); + } + + @Override + public void deserialize(Block block, int index, InOut state) + { + state.set(block, index); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairStateSerializer.java index fe0ae2b16b3f..91c1ea0c57a5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairStateSerializer.java @@ -17,23 +17,42 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.MapType; +import io.trino.spi.function.Convention; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; +import static java.util.Objects.requireNonNull; + public class KeyValuePairStateSerializer implements AccumulatorStateSerializer { - private final MapType mapType; + private final Type mapType; private final BlockPositionEqual keyEqual; private final BlockPositionHashCode keyHashCode; - public KeyValuePairStateSerializer(MapType mapType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode) + public KeyValuePairStateSerializer( + @TypeParameter("MAP(K, V)") Type mapType, + @OperatorDependency( + operator = OperatorType.EQUAL, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) + BlockPositionEqual keyEqual, + @OperatorDependency( + operator = OperatorType.HASH_CODE, + argumentTypes = "K", + convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) + BlockPositionHashCode keyHashCode) { - this.mapType = mapType; - this.keyEqual = keyEqual; - this.keyHashCode = keyHashCode; + this.mapType = requireNonNull(mapType, "mapType is null"); + this.keyEqual = requireNonNull(keyEqual, "keyEqual is null"); + this.keyHashCode = requireNonNull(keyHashCode, "keyHashCode is null"); } @Override @@ -56,6 +75,6 @@ public void serialize(KeyValuePairsState state, BlockBuilder out) @Override public void deserialize(Block block, int index, KeyValuePairsState state) { - state.set(new KeyValuePairs(mapType.getObject(block, index), state.getKeyType(), keyEqual, keyHashCode, state.getValueType())); + state.set(new KeyValuePairs((Block) mapType.getObject(block, index), state.getKeyType(), keyEqual, keyHashCode, state.getValueType())); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsState.java index f6ddaa065a1c..3eca35be5de5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsState.java @@ -18,7 +18,11 @@ import io.trino.spi.function.AccumulatorStateMetadata; import io.trino.spi.type.Type; -@AccumulatorStateMetadata(stateFactoryClass = KeyValuePairsStateFactory.class, stateSerializerClass = KeyValuePairStateSerializer.class) +@AccumulatorStateMetadata( + stateFactoryClass = KeyValuePairsStateFactory.class, + stateSerializerClass = KeyValuePairStateSerializer.class, + typeParameters = {"K", "V"}, + serializedType = "MAP(K, V)") public interface KeyValuePairsState extends AccumulatorState { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsStateFactory.java index bd2066e2c340..cd5a9fbbcf35 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/KeyValuePairsStateFactory.java @@ -17,6 +17,7 @@ import io.trino.operator.aggregation.KeyValuePairs; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; import org.openjdk.jol.info.ClassLayout; @@ -28,7 +29,7 @@ public class KeyValuePairsStateFactory private final Type keyType; private final Type valueType; - public KeyValuePairsStateFactory(Type keyType, Type valueType) + public KeyValuePairsStateFactory(@TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType) { this.keyType = keyType; this.valueType = valueType; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNStateFactory.java deleted file mode 100644 index 1475a9413021..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNStateFactory.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.state; - -import io.trino.array.ObjectBigArray; -import io.trino.operator.aggregation.TypedHeap; -import io.trino.spi.function.AccumulatorState; -import io.trino.spi.function.AccumulatorStateFactory; -import org.openjdk.jol.info.ClassLayout; - -public class MinMaxNStateFactory - implements AccumulatorStateFactory -{ - @Override - public MinMaxNState createSingleState() - { - return new SingleMinMaxNState(); - } - - @Override - public MinMaxNState createGroupedState() - { - return new GroupedMinMaxNState(); - } - - public static class GroupedMinMaxNState - extends AbstractGroupedAccumulatorState - implements MinMaxNState - { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedMinMaxNState.class).instanceSize(); - private final ObjectBigArray heaps = new ObjectBigArray<>(); - private long size; - - @Override - public void ensureCapacity(long size) - { - heaps.ensureCapacity(size); - } - - @Override - public long getEstimatedSize() - { - return INSTANCE_SIZE + heaps.sizeOf() + size; - } - - @Override - public TypedHeap getTypedHeap() - { - return heaps.get(getGroupId()); - } - - @Override - public void setTypedHeap(TypedHeap value) - { - TypedHeap previous = getTypedHeap(); - if (previous != null) { - size -= previous.getEstimatedSize(); - } - heaps.set(getGroupId(), value); - size += value.getEstimatedSize(); - } - - @Override - public void addMemoryUsage(long memory) - { - size += memory; - } - } - - public static class SingleMinMaxNState - implements MinMaxNState - { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMinMaxNState.class).instanceSize(); - private TypedHeap typedHeap; - - public SingleMinMaxNState() {} - - // for copying - private SingleMinMaxNState(TypedHeap typedHeap) - { - this.typedHeap = typedHeap; - } - - @Override - public long getEstimatedSize() - { - long estimatedSize = INSTANCE_SIZE; - if (typedHeap != null) { - estimatedSize += typedHeap.getEstimatedSize(); - } - return estimatedSize; - } - - @Override - public TypedHeap getTypedHeap() - { - return typedHeap; - } - - @Override - public void setTypedHeap(TypedHeap typedHeap) - { - this.typedHeap = typedHeap; - } - - @Override - public void addMemoryUsage(long memory) - { - } - - @Override - public AccumulatorState copy() - { - TypedHeap typedHeapCopy = null; - if (typedHeap != null) { - typedHeapCopy = typedHeap.copy(); - } - return new SingleMinMaxNState(typedHeapCopy); - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNStateSerializer.java deleted file mode 100644 index 5b85ec9d7194..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/MinMaxNStateSerializer.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.state; - -import com.google.common.collect.ImmutableList; -import io.trino.operator.aggregation.TypedHeap; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.RowType; -import io.trino.spi.type.Type; - -import java.lang.invoke.MethodHandle; - -import static io.trino.spi.type.BigintType.BIGINT; -import static java.lang.Math.toIntExact; - -public class MinMaxNStateSerializer - implements AccumulatorStateSerializer -{ - private final MethodHandle greaterThanMethod; - private final Type elementType; - private final ArrayType arrayType; - private final Type serializedType; - - public MinMaxNStateSerializer(MethodHandle greaterThanMethod, Type elementType) - { - this.greaterThanMethod = greaterThanMethod; - this.elementType = elementType; - this.arrayType = new ArrayType(elementType); - this.serializedType = RowType.anonymous(ImmutableList.of(BIGINT, arrayType)); - } - - @Override - public Type getSerializedType() - { - return serializedType; - } - - @Override - public void serialize(MinMaxNState state, BlockBuilder out) - { - TypedHeap heap = state.getTypedHeap(); - if (heap == null) { - out.appendNull(); - return; - } - - BlockBuilder blockBuilder = out.beginBlockEntry(); - BIGINT.writeLong(blockBuilder, heap.getCapacity()); - BlockBuilder elements = blockBuilder.beginBlockEntry(); - heap.writeAll(elements); - blockBuilder.closeEntry(); - - out.closeEntry(); - } - - @Override - public void deserialize(Block block, int index, MinMaxNState state) - { - Block currentBlock = (Block) serializedType.getObject(block, index); - int capacity = toIntExact(BIGINT.getLong(currentBlock, 0)); - Block heapBlock = arrayType.getObject(currentBlock, 1); - TypedHeap heap = new TypedHeap(greaterThanMethod, elementType, capacity); - heap.addAll(heapBlock); - state.setTypedHeap(heap); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java index 7d8772441603..70f655673a47 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableBooleanStateSerializer.java @@ -13,37 +13,20 @@ */ package io.trino.operator.aggregation.state; -import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.Type; -import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static java.util.Objects.requireNonNull; public class NullableBooleanStateSerializer implements AccumulatorStateSerializer { - private final Type type; - - @UsedByGeneratedCode - public NullableBooleanStateSerializer() - { - this(BOOLEAN); - } - - public NullableBooleanStateSerializer(Type type) - { - this.type = requireNonNull(type, "type is null"); - checkArgument(type.getJavaType() == boolean.class, "Type must use boolean stack type: " + type); - } - @Override public Type getSerializedType() { - return type; + return BOOLEAN; } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java index 381616a6848b..62fee2cc383a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableDoubleStateSerializer.java @@ -13,37 +13,20 @@ */ package io.trino.operator.aggregation.state; -import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.Type; -import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.DoubleType.DOUBLE; -import static java.util.Objects.requireNonNull; public class NullableDoubleStateSerializer implements AccumulatorStateSerializer { - private final Type type; - - @UsedByGeneratedCode - public NullableDoubleStateSerializer() - { - this(DOUBLE); - } - - public NullableDoubleStateSerializer(Type type) - { - this.type = requireNonNull(type, "type is null"); - checkArgument(type.getJavaType() == double.class, "Type must use double stack type: " + type); - } - @Override public Type getSerializedType() { - return type; + return DOUBLE; } @Override @@ -53,7 +36,7 @@ public void serialize(NullableDoubleState state, BlockBuilder out) out.appendNull(); } else { - type.writeDouble(out, state.getValue()); + DOUBLE.writeDouble(out, state.getValue()); } } @@ -65,7 +48,7 @@ public void deserialize(Block block, int index, NullableDoubleState state) } else { state.setNull(false); - state.setValue(type.getDouble(block, index)); + state.setValue(DOUBLE.getDouble(block, index)); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java index ea275305fdfb..0ca4681ef825 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/NullableLongStateSerializer.java @@ -13,37 +13,20 @@ */ package io.trino.operator.aggregation.state; -import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.type.Type; -import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BigintType.BIGINT; -import static java.util.Objects.requireNonNull; public class NullableLongStateSerializer implements AccumulatorStateSerializer { - private final Type type; - - @UsedByGeneratedCode - public NullableLongStateSerializer() - { - this(BIGINT); - } - - public NullableLongStateSerializer(Type type) - { - this.type = requireNonNull(type, "type is null"); - checkArgument(type.getJavaType() == long.class, "Type must use long stack type: " + type); - } - @Override public Type getSerializedType() { - return type; + return BIGINT; } @Override @@ -53,7 +36,7 @@ public void serialize(NullableLongState state, BlockBuilder out) out.appendNull(); } else { - type.writeLong(out, state.getValue()); + BIGINT.writeLong(out, state.getValue()); } } @@ -65,7 +48,7 @@ public void deserialize(Block block, int index, NullableLongState state) } else { state.setNull(false); - state.setValue(type.getLong(block, index)); + state.setValue(BIGINT.getLong(block, index)); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestState.java index 27224fb75f6e..33d14d446353 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestState.java @@ -15,7 +15,11 @@ import io.airlift.stats.QuantileDigest; import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; +@AccumulatorStateMetadata( + stateFactoryClass = QuantileDigestStateFactory.class, + stateSerializerClass = QuantileDigestStateSerializer.class) public interface QuantileDigestState extends AccumulatorState { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestStateSerializer.java index b0d657c1ac70..add0f699db4b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/QuantileDigestStateSerializer.java @@ -20,14 +20,16 @@ import io.trino.spi.type.QuantileDigestType; import io.trino.spi.type.Type; +import static io.trino.spi.type.BigintType.BIGINT; + public class QuantileDigestStateSerializer implements AccumulatorStateSerializer { private final QuantileDigestType type; - public QuantileDigestStateSerializer(Type elementType) + public QuantileDigestStateSerializer() { - this.type = new QuantileDigestType(elementType); + this.type = new QuantileDigestType(BIGINT); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java index bd68892ea749..72b1f109ab72 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java @@ -13,10 +13,12 @@ */ package io.trino.operator.aggregation.state; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Ordering; import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.BytecodeNode; import io.airlift.bytecode.ClassDefinition; import io.airlift.bytecode.DynamicClassLoader; import io.airlift.bytecode.FieldDefinition; @@ -41,6 +43,9 @@ import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateMetadata; import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.GroupedAccumulatorState; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InternalDataAccessor; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.gen.CallSiteBinder; @@ -58,6 +63,8 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; import static com.google.common.base.CaseFormat.LOWER_CAMEL; import static com.google.common.base.CaseFormat.UPPER_CAMEL; @@ -74,10 +81,13 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.add; import static io.airlift.bytecode.expression.BytecodeExpressions.constantBoolean; import static io.airlift.bytecode.expression.BytecodeExpressions.constantClass; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull; import static io.airlift.bytecode.expression.BytecodeExpressions.constantNumber; import static io.airlift.bytecode.expression.BytecodeExpressions.constantString; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.expression.BytecodeExpressions.defaultValue; import static io.airlift.bytecode.expression.BytecodeExpressions.equal; import static io.airlift.bytecode.expression.BytecodeExpressions.getStatic; @@ -126,34 +136,13 @@ private static Class getBigArrayType(Class type) return ObjectBigArray.class; } - public static Type getSerializedType(Class clazz) - { - return getSerializedType(clazz, ImmutableMap.of()); - } - - public static Type getSerializedType(Class clazz, Map fieldTypes) - { - AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); - if (metadata != null && metadata.stateSerializerClass() != AccumulatorStateSerializer.class) { - try { - AccumulatorStateSerializer stateSerializer = (AccumulatorStateSerializer) metadata.stateSerializerClass().getConstructor().newInstance(); - return stateSerializer.getSerializedType(); - } - catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { - throw new RuntimeException(e); - } - } - - List fields = enumerateFields(clazz, fieldTypes); - return getSerializedType(fields); - } - public static AccumulatorStateSerializer generateStateSerializer(Class clazz) { return generateStateSerializer(clazz, ImmutableMap.of()); } - public static AccumulatorStateSerializer generateStateSerializer(Class clazz, Map fieldTypes) + @VisibleForTesting + static AccumulatorStateSerializer generateStateSerializer(Class clazz, Map fieldTypes) { AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); if (metadata != null && metadata.stateSerializerClass() != AccumulatorStateSerializer.class) { @@ -217,7 +206,7 @@ private static Type getSerializedType(List fields) return UNKNOWN; } - private static AccumulatorStateMetadata getMetadataAnnotation(Class clazz) + public static AccumulatorStateMetadata getMetadataAnnotation(Class clazz) { AccumulatorStateMetadata metadata = clazz.getAnnotation(AccumulatorStateMetadata.class); if (metadata != null) { @@ -354,12 +343,334 @@ private static Method getGetter(Class clazz, StateField field) } } + public static AccumulatorStateFactory generateInOutStateFactory(Type type) + { + CallSiteBinder callSiteBinder = new CallSiteBinder(); + ClassDefinition singleStateClassDefinition = generateInOutSingleStateClass(type, callSiteBinder); + ClassDefinition groupedStateClassDefinition = generateInOutGroupedStateClass(type, callSiteBinder); + + DynamicClassLoader classLoader = new DynamicClassLoader(StateCompiler.class.getClassLoader(), callSiteBinder.getBindings()); + Class singleStateClass = defineClass(singleStateClassDefinition, InOut.class, classLoader); + Class groupedStateClass = defineClass(groupedStateClassDefinition, InOut.class, classLoader); + + return generateStateFactory(InOut.class, singleStateClass, groupedStateClass, classLoader); + } + + private static ClassDefinition generateInOutSingleStateClass(Type type, CallSiteBinder callSiteBinder) + { + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("SingleInOut"), + type(Object.class), + type(InOut.class), + type(InternalDataAccessor.class)); + + estimatedSize(definition); + + // Generate constructor + MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); + + constructor.getBody() + .append(constructor.getThis()) + .invokeConstructor(Object.class); + + // Generate fields + FieldDefinition valueField = definition.declareField(a(PRIVATE), "value", inOutGetterReturnType(type)); + Function valueGetter = scope -> scope.getThis().getField(valueField); + + Optional nullField; + Function nullGetter; + if (type.getJavaType().isPrimitive()) { + nullField = Optional.of(definition.declareField(a(PRIVATE), "valueIdNull", boolean.class)); + constructor.getBody().append(constructor.getThis().setField(nullField.get(), constantTrue())); + nullGetter = scope -> scope.getThis().getField(nullField.get()); + } + else { + nullField = Optional.empty(); + nullGetter = scope -> isNull(valueGetter.apply(scope)); + } + + constructor.getBody() + .ret(); + + inOutSingleCopy(definition, valueField, nullField); + + Function setNullGenerator = scope -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.setField(field, constantTrue()))); + bytecodeBlock.append(thisVariable.setField(valueField, defaultValue(valueField.getType()))); + return bytecodeBlock; + }; + + BiFunction setValueGenerator = (scope, value) -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.setField(field, constantFalse()))); + bytecodeBlock.append(thisVariable.setField(valueField, value)); + return bytecodeBlock; + }; + + generateInOutMethods(type, definition, valueGetter, nullGetter, setNullGenerator, setValueGenerator, callSiteBinder); + return definition; + } + + private static ClassDefinition generateInOutGroupedStateClass(Type type, CallSiteBinder callSiteBinder) + { + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("GroupedInOut"), // todo add type + type(Object.class), + type(InOut.class), + type(GroupedAccumulatorState.class), + type(InternalDataAccessor.class)); + + estimatedSize(definition); + + MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); + constructor.getBody() + .append(constructor.getThis()) + .invokeConstructor(Object.class); + + FieldDefinition groupIdField = definition.declareField(a(PRIVATE), "groupId", long.class); + + Class valueElementType = inOutGetterReturnType(type); + FieldDefinition valueField = definition.declareField(a(PRIVATE, FINAL), "value", getBigArrayType(valueElementType)); + constructor.getBody().append(constructor.getThis().setField(valueField, newInstance(valueField.getType()))); + Function valueGetter = scope -> scope.getThis().getField(valueField).invoke("get", valueElementType, scope.getThis().getField(groupIdField)); + + Optional nullField; + Function nullGetter; + if (type.getJavaType().isPrimitive()) { + nullField = Optional.of(definition.declareField(a(PRIVATE, FINAL), "valueIdNull", BooleanBigArray.class)); + constructor.getBody().append(constructor.getThis().setField(nullField.get(), newInstance(BooleanBigArray.class, constantTrue()))); + nullGetter = scope -> scope.getThis().getField(nullField.get()).invoke("get", boolean.class, scope.getThis().getField(groupIdField)); + } + else { + nullField = Optional.empty(); + nullGetter = scope -> isNull(valueGetter.apply(scope)); + } + + constructor.getBody() + .ret(); + + inOutGroupedSetGroupId(definition, groupIdField); + inOutGroupedEnsureCapacity(definition, valueField, nullField); + inOutGroupedCopy(definition, valueField, nullField); + + Function setNullGenerator = scope -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.getField(field).invoke("set", void.class, thisVariable.getField(groupIdField), constantTrue()))); + bytecodeBlock.append(thisVariable.getField(valueField).invoke("set", void.class, thisVariable.getField(groupIdField), defaultValue(valueElementType))); + return bytecodeBlock; + }; + BiFunction setValueGenerator = (scope, value) -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.getField(field).invoke("set", void.class, thisVariable.getField(groupIdField), constantFalse()))); + bytecodeBlock.append(thisVariable.getField(valueField).invoke("set", void.class, thisVariable.getField(groupIdField), value.cast(valueElementType))); + return bytecodeBlock; + }; + + generateInOutMethods(type, definition, valueGetter, nullGetter, setNullGenerator, setValueGenerator, callSiteBinder); + + return definition; + } + + private static void generateInOutMethods(Type type, + ClassDefinition definition, + Function valueGetter, + Function nullGetter, + Function setNullGenerator, + BiFunction setValueGenerator, + CallSiteBinder callSiteBinder) + { + SqlTypeBytecodeExpression sqlType = constantType(callSiteBinder, type); + + generateInOutGetType(definition, sqlType); + generateInOutIsNull(definition, nullGetter); + generateInOutGetBlockBuilder(definition, sqlType, valueGetter); + generateInOutSetBlockPosition(definition, sqlType, setNullGenerator, setValueGenerator); + generateInOutSetInOut(definition, type, setNullGenerator, setValueGenerator); + generateInOutGetValue(definition, type, valueGetter); + } + + private static void estimatedSize(ClassDefinition definition) + { + FieldDefinition instanceSize = generateInstanceSize(definition); + + // Add getter for class size + definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class)) + .getBody() + .getStaticField(instanceSize) + .retLong(); + } + + private static void inOutSingleCopy(ClassDefinition definition, FieldDefinition valueField, Optional nullField) + { + MethodDefinition copy = definition.declareMethod(a(PUBLIC), "copy", type(AccumulatorState.class)); + Variable thisVariable = copy.getThis(); + BytecodeBlock body = copy.getBody(); + + Variable copyVariable = copy.getScope().declareVariable(definition.getType(), "copy"); + body.append(copyVariable.set(newInstance(definition.getType()))); + body.append(copyVariable.setField(valueField, thisVariable.getField(valueField))); + nullField.ifPresent(field -> body.append(copyVariable.setField(field, thisVariable.getField(field)))); + body.append(copyVariable.ret()); + } + + private static void inOutGroupedSetGroupId(ClassDefinition definition, FieldDefinition groupIdField) + { + Parameter groupIdArg = arg("groupId", long.class); + MethodDefinition method = definition.declareMethod(a(PUBLIC), "setGroupId", type(void.class), groupIdArg); + method.getBody() + .append(method.getThis().setField(groupIdField, groupIdArg)) + .ret(); + } + + private static void inOutGroupedEnsureCapacity(ClassDefinition definition, FieldDefinition valueField, Optional nullField) + { + Parameter size = arg("size", long.class); + MethodDefinition method = definition.declareMethod(a(PUBLIC), "ensureCapacity", type(void.class), size); + Variable thisVariable = method.getThis(); + BytecodeBlock body = method.getBody(); + + body.append(thisVariable.getField(valueField).invoke("ensureCapacity", void.class, size)); + nullField.ifPresent(field -> body.append(thisVariable.getField(field).invoke("ensureCapacity", void.class, size))); + body.ret(); + } + + private static void inOutGroupedCopy(ClassDefinition definition, FieldDefinition valueField, Optional nullField) + { + MethodDefinition copy = definition.declareMethod(a(PUBLIC), "copy", type(AccumulatorState.class)); + Variable thisVariable = copy.getThis(); + BytecodeBlock body = copy.getBody(); + + Variable copyVariable = copy.getScope().declareVariable(definition.getType(), "copy"); + body.append(copyVariable.set(newInstance(definition.getType()))); + copyBigArray(body, thisVariable, copyVariable, valueField); + nullField.ifPresent(field -> copyBigArray(body, thisVariable, copyVariable, field)); + body.append(copyVariable.ret()); + } + + private static void copyBigArray(BytecodeBlock body, Variable source, Variable destination, FieldDefinition bigArrayField) + { + body.append(destination.getField(bigArrayField).invoke("ensureCapacity", void.class, source.getField(bigArrayField).invoke("getCapacity", long.class))); + body.append(source.getField(bigArrayField).invoke( + "copyTo", + void.class, + constantLong(0), + destination.getField(bigArrayField), + constantLong(0), + source.getField(bigArrayField).invoke("getCapacity", long.class))); + } + + private static void generateInOutGetType(ClassDefinition definition, SqlTypeBytecodeExpression sqlType) + { + definition.declareMethod(a(PUBLIC), "getType", type(Type.class)) + .getBody() + .append(sqlType.ret()); + } + + private static void generateInOutIsNull(ClassDefinition definition, Function nullGetter) + { + MethodDefinition isNullMethod = definition.declareMethod(a(PUBLIC), "isNull", type(boolean.class)); + isNullMethod.getBody().append(nullGetter.apply(isNullMethod.getScope()).ret()); + } + + private static void generateInOutGetBlockBuilder(ClassDefinition definition, SqlTypeBytecodeExpression sqlType, Function valueGetter) + { + Parameter blockBuilderArg = arg("blockBuilder", BlockBuilder.class); + MethodDefinition getBlockBuilderMethod = definition.declareMethod(a(PUBLIC), "get", type(void.class), blockBuilderArg); + Variable thisVariable = getBlockBuilderMethod.getThis(); + BytecodeBlock body = getBlockBuilderMethod.getBody(); + + body.append(new IfStatement() + .condition(thisVariable.invoke("isNull", boolean.class)) + .ifTrue(blockBuilderArg.invoke("appendNull", BlockBuilder.class).pop()) + .ifFalse(sqlType.writeValue(blockBuilderArg, valueGetter.apply(getBlockBuilderMethod.getScope())))); + body.ret(); + } + + private static void generateInOutSetBlockPosition( + ClassDefinition definition, + SqlTypeBytecodeExpression sqlType, + Function setNullGenerator, + BiFunction setValueGenerator) + { + Parameter blockArg = arg("block", Block.class); + Parameter positionArg = arg("position", int.class); + MethodDefinition setBlockBuilderMethod = definition.declareMethod(a(PUBLIC), "set", type(void.class), blockArg, positionArg); + BytecodeBlock body = setBlockBuilderMethod.getBody(); + + body.append(new IfStatement() + .condition(blockArg.invoke("isNull", boolean.class, positionArg)) + .ifTrue(setNullGenerator.apply(setBlockBuilderMethod.getScope())) + .ifFalse(setValueGenerator.apply(setBlockBuilderMethod.getScope(), sqlType.getValue(blockArg, positionArg)))); + body.ret(); + } + + private static void generateInOutSetInOut( + ClassDefinition definition, + Type type, + Function setNullGenerator, + BiFunction setValueGenerator) + { + Parameter otherState = arg("otherState", InOut.class); + MethodDefinition setter = definition.declareMethod(a(PUBLIC), "set", type(void.class), otherState); + BytecodeBlock body = setter.getBody(); + + body.append(new IfStatement() + .condition(otherState.invoke("isNull", boolean.class)) + .ifTrue(setNullGenerator.apply(setter.getScope())) + .ifFalse(setValueGenerator.apply(setter.getScope(), otherState.cast(InternalDataAccessor.class).invoke(inOutGetterName(type), inOutGetterReturnType(type))))); + body.ret(); + } + + private static void generateInOutGetValue(ClassDefinition definition, Type type, Function valueGetter) + { + MethodDefinition getter = definition.declareMethod(a(PUBLIC), inOutGetterName(type), type(inOutGetterReturnType(type))); + getter.getBody().append(valueGetter.apply(getter.getScope()).ret()); + } + + private static Class inOutGetterReturnType(Type type) + { + Class javaType = type.getJavaType(); + if (javaType.equals(boolean.class)) { + return boolean.class; + } + if (javaType.equals(long.class)) { + return long.class; + } + if (javaType.equals(double.class)) { + return double.class; + } + return Object.class; + } + + private static String inOutGetterName(Type type) + { + Class javaType = type.getJavaType(); + if (javaType.equals(boolean.class)) { + return "getBooleanValue"; + } + if (javaType.equals(long.class)) { + return "getLongValue"; + } + if (javaType.equals(double.class)) { + return "getDoubleValue"; + } + return "getObjectValue"; + } + public static AccumulatorStateFactory generateStateFactory(Class clazz) { return generateStateFactory(clazz, ImmutableMap.of()); } - public static AccumulatorStateFactory generateStateFactory(Class clazz, Map fieldTypes) + @VisibleForTesting + static AccumulatorStateFactory generateStateFactory(Class clazz, Map fieldTypes) { AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); if (metadata != null && metadata.stateFactoryClass() != AccumulatorStateFactory.class) { @@ -377,6 +688,15 @@ public static AccumulatorStateFactory generateSt Class singleStateClass = generateSingleStateClass(clazz, fieldTypes, classLoader); Class groupedStateClass = generateGroupedStateClass(clazz, fieldTypes, classLoader); + return generateStateFactory(clazz, singleStateClass, groupedStateClass, classLoader); + } + + private static AccumulatorStateFactory generateStateFactory( + Class clazz, + Class singleStateClass, + Class groupedStateClass, + DynamicClassLoader classLoader) + { ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), makeClassName(clazz.getSimpleName() + "Factory"), diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java index 9ddbd3be8a9d..1d5244071810 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java @@ -40,6 +40,16 @@ public FunctionImplementationDependency(QualifiedName fullyQualifiedName, List getArgumentTypes() + { + return argumentTypes; + } + @Override public void declareDependencies(FunctionDependencyDeclarationBuilder builder) { diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java index 900caa49b3ac..07df7f14616b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java @@ -102,7 +102,7 @@ private Factory() {} public static ImplementationDependency createDependency(Annotation annotation, Set literalParameters, Class type) { if (annotation instanceof TypeParameter) { - return new TypeImplementationDependency(((TypeParameter) annotation).value()); + return new TypeImplementationDependency(parseTypeSignature(((TypeParameter) annotation).value(), literalParameters)); } if (annotation instanceof LiteralParameter) { return new LiteralImplementationDependency(((LiteralParameter) annotation).value()); diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java index ef37c5bd2c3c..e3d79441f05b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java @@ -40,6 +40,16 @@ protected ScalarImplementationDependency(InvocationConvention invocationConventi } } + public InvocationConvention getInvocationConvention() + { + return invocationConvention; + } + + public Class getType() + { + return type; + } + protected abstract FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention); @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java index 6447a910ce48..f44a70bc6ce7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java @@ -13,7 +13,6 @@ */ package io.trino.operator.annotations; -import com.google.common.collect.ImmutableSet; import io.trino.metadata.FunctionBinding; import io.trino.metadata.FunctionDependencies; import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; @@ -22,7 +21,6 @@ import java.util.Objects; import static io.trino.metadata.SignatureBinder.applyBoundVariables; -import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature; import static java.util.Objects.requireNonNull; public final class TypeImplementationDependency @@ -30,9 +28,14 @@ public final class TypeImplementationDependency { private final TypeSignature signature; - public TypeImplementationDependency(String signature) + public TypeImplementationDependency(TypeSignature signature) { - this.signature = parseTypeSignature(requireNonNull(signature, "signature is null"), ImmutableSet.of()); + this.signature = requireNonNull(signature, "signature is null"); + } + + public TypeSignature getSignature() + { + return signature; } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesScalarFunctionImplementation.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesScalarFunctionImplementation.java index 6cec79f04836..594432a16498 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesScalarFunctionImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesScalarFunctionImplementation.java @@ -32,8 +32,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; import static io.trino.spi.function.ScalarFunctionAdapter.NullAdaptationPolicy.RETURN_NULL_ON_NULL; import static java.lang.String.format; import static java.util.Comparator.comparingInt; @@ -204,11 +202,18 @@ private static int computeScore(InvocationConvention callingConvention) { int score = 0; for (InvocationArgumentConvention argument : callingConvention.getArgumentConventions()) { - if (argument == NULL_FLAG) { - score += 1; - } - else if (argument == BLOCK_POSITION) { - score += 1000; + switch (argument) { + case NULL_FLAG: + score += 1; + break; + case BLOCK_POSITION: + score += 1000; + break; + case IN_OUT: + score += 10_000; + break; + default: + break; } } return score; diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java index 45f4309ce8b7..fff0a6588d04 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java @@ -33,6 +33,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.InOut; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.function.IsNull; @@ -77,6 +78,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -262,6 +264,9 @@ private static MethodType javaMethodType(ParametricScalarImplementationChoice ch methodHandleParameterTypes.add(Block.class); methodHandleParameterTypes.add(int.class); break; + case IN_OUT: + methodHandleParameterTypes.add(InOut.class); + break; case FUNCTION: methodHandleParameterTypes.add(choice.getLambdaInterfaces().get(lambdaArgumentIndex)); lambdaArgumentIndex++; @@ -622,6 +627,9 @@ else if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) { Annotation[] parameterAnnotations = method.getParameterAnnotations()[parameterIndex + 1]; checkState(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); } + else if (parameterType.equals(InOut.class)) { + argumentConvention = IN_OUT; + } else { // USE_NULL_FLAG or RETURN_NULL_ON_NULL checkCondition(parameterType == Void.class || !Primitives.isWrapperType(parameterType), FUNCTION_IMPLEMENTATION_ERROR, "A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method [%s]", method); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java index b8fe2695a288..38a6b7512f5a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java @@ -32,6 +32,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.InOut; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.type.Type; @@ -337,6 +338,16 @@ else if (type == ConnectorSession.class) { } currentParameterIndex++; break; + case IN_OUT: + block.append(arguments.get(realParameterIndex)); + if (!functionNullability.isArgumentNullable(realParameterIndex)) { + block.append(arguments.get(realParameterIndex)); + block.invokeVirtual(InOut.class, "isNull", boolean.class); + block.putVariable(scope.getVariable("wasNull")); + block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes))); + } + currentParameterIndex++; + break; case FUNCTION: Class lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex); block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java index 9475cf135700..da2a78523bdb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java @@ -51,8 +51,8 @@ *
  * - Aggregation
  *          GROUP BY (k)
- *          F1(x)
- *          F2(x)
+ *          F1(s0, s1, ...)
+ *          F2(s0, s1, ...)
  *      - Aggregation
  *             GROUP BY (k, s0, s1, ...)
  *          - X
@@ -67,6 +67,7 @@ public class SingleDistinctAggregationToGroupBy
             .matching(SingleDistinctAggregationToGroupBy::hasSingleDistinctInput)
             .matching(SingleDistinctAggregationToGroupBy::allDistinctAggregates)
             .matching(SingleDistinctAggregationToGroupBy::noFilters)
+            .matching(SingleDistinctAggregationToGroupBy::noOrdering)
             .matching(SingleDistinctAggregationToGroupBy::noMasks);
 
     private static boolean hasSingleDistinctInput(AggregationNode aggregationNode)
@@ -89,6 +90,13 @@ private static boolean noFilters(AggregationNode aggregationNode)
                 .noneMatch(aggregation -> aggregation.getFilter().isPresent());
     }
 
+    private static boolean noOrdering(AggregationNode aggregationNode)
+    {
+        return aggregationNode.getAggregations()
+                .values().stream()
+                .noneMatch(aggregation -> aggregation.getOrderingScheme().isPresent());
+    }
+
     private static boolean noMasks(AggregationNode aggregationNode)
     {
         return aggregationNode.getAggregations()
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java
index 333f8f2b2e3f..455e86a69634 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java
@@ -72,6 +72,7 @@
 import static io.trino.metadata.FunctionManager.createTestingFunctionManager;
 import static io.trino.metadata.MetadataManager.createTestMetadataManager;
 import static io.trino.operator.aggregation.AggregationFromAnnotationsParser.parseFunctionDefinitions;
+import static io.trino.operator.aggregation.AggregationFromAnnotationsParser.toAccumulatorStateDetails;
 import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX;
 import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL;
 import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
@@ -174,7 +175,7 @@ public static void output(@AggregationState NullableDoubleState state, BlockBuil
     public void testInputParameterOrderEnforced()
     {
         assertThatThrownBy(() -> parseFunctionDefinitions(InputParametersWrongOrder.class))
-                .hasMessage("Expected input function non-dependency parameters to begin with state type NullableDoubleState: " +
+                .hasMessage("Expected input function non-dependency parameters to begin with state types [NullableDoubleState]: " +
                         "public static void io.trino.operator.TestAnnotationEngineForAggregates$InputParametersWrongOrder.input(double,io.trino.operator.aggregation.state.NullableDoubleState)");
     }
 
@@ -342,7 +343,7 @@ public void testSimpleGenericAggregationFunctionParse()
         assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with two generic implementations");
         assertTrue(aggregation.getFunctionMetadata().isDeterministic());
         assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature);
-        assertEquals(aggregation.getStateClass(), NullableLongState.class);
+        assertEquals(aggregation.getStateDetails(), ImmutableList.of(toAccumulatorStateDetails(NullableLongState.class, ImmutableList.of())));
         ParametricImplementationsGroup implementations = aggregation.getImplementations();
         assertImplementationCount(implementations, 0, 0, 2);
         AggregationImplementation implementationDouble = implementations.getGenericImplementations().stream()
@@ -1007,7 +1008,7 @@ public void testFixedTypeParameterInjectionAggregateFunctionParse()
         assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with fixed parameter type injected");
         assertTrue(aggregation.getFunctionMetadata().isDeterministic());
         assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature);
-        assertEquals(aggregation.getStateClass(), NullableDoubleState.class);
+        assertEquals(aggregation.getStateDetails(), ImmutableList.of(toAccumulatorStateDetails(NullableDoubleState.class, ImmutableList.of())));
         ParametricImplementationsGroup implementations = aggregation.getImplementations();
         assertImplementationCount(implementations, 1, 0, 0);
         AggregationImplementation implementationDouble = implementations.getExactImplementations().get(expectedSignature);
@@ -1071,7 +1072,7 @@ public void testPartiallyFixedTypeParameterInjectionAggregateFunctionParse()
         assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with fixed parameter type injected");
         assertTrue(aggregation.getFunctionMetadata().isDeterministic());
         assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature);
-        assertEquals(aggregation.getStateClass(), NullableDoubleState.class);
+        assertEquals(aggregation.getStateDetails(), ImmutableList.of(toAccumulatorStateDetails(NullableDoubleState.class, ImmutableList.of())));
         ParametricImplementationsGroup implementations = aggregation.getImplementations();
         assertImplementationCount(implementations, 0, 0, 1);
         AggregationImplementation implementationDouble = getOnlyElement(implementations.getGenericImplementations());
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java
index 496b6eb8d066..9f5967e9c3c2 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java
@@ -634,7 +634,8 @@ public void testSpillerFailure()
         RowPagesBuilder rowPagesBuilder = rowPagesBuilder(false, hashChannels, types);
         List input = rowPagesBuilder
                 .addSequencePage(10, 100, 0, 100, 0)
-                .addSequencePage(10, 100, 0, 200, 0)
+                // current accumulator allows 1024 values without using revocable memory, so add enough values to cause revocable memory usage
+                .addSequencePage(2_000, 100, 0, 200, 0)
                 .addSequencePage(10, 100, 0, 300, 0)
                 .build();
 
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java
index 9a129fb259e4..92bcfb2aeed8 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java
@@ -63,10 +63,10 @@ public static void assertAggregation(TestingFunctionResolution functionResolutio
     public static BiFunction makeValidityAssertion(Object expectedValue)
     {
         if (expectedValue instanceof Double && !expectedValue.equals(Double.NaN)) {
-            return (actual, expected) -> Precision.equals((double) actual, (double) expected, 1.0e-10);
+            return (actual, expected) -> actual != null && expected != null && Precision.equals((double) actual, (double) expected, 1.0e-10);
         }
         if (expectedValue instanceof Float && !expectedValue.equals(Float.NaN)) {
-            return (actual, expected) -> Precision.equals((float) actual, (float) expected, 1.0e-10f);
+            return (actual, expected) -> actual != null && expected != null && Precision.equals((float) actual, (float) expected, 1.0e-10f);
         }
         return Objects::equals;
     }
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/BenchmarkGroupedTypedHistogram.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/BenchmarkGroupedTypedHistogram.java
index 795c8f4723c6..a0d36e17485f 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/BenchmarkGroupedTypedHistogram.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/BenchmarkGroupedTypedHistogram.java
@@ -16,7 +16,6 @@
 import com.google.common.collect.ImmutableList;
 import io.trino.metadata.TestingFunctionResolution;
 import io.trino.operator.GroupByIdBlock;
-import io.trino.operator.aggregation.histogram.Histogram;
 import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.sql.tree.QualifiedName;
@@ -134,7 +133,7 @@ public GroupedAggregator testSharedGroupWithLargeBlocksRunner(Data data)
     private static TestingAggregationFunction getInternalAggregationFunctionVarChar()
     {
         TestingFunctionResolution functionResolution = new TestingFunctionResolution();
-        return functionResolution.getAggregateFunction(QualifiedName.of(Histogram.NAME), fromTypes(VARCHAR));
+        return functionResolution.getAggregateFunction(QualifiedName.of("histogram"), fromTypes(VARCHAR));
     }
 
     public static void main(String[] args)
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java
index cfb3d1dbc55c..beaf53e56183 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java
@@ -134,12 +134,12 @@ public void testOverflowOnOutput()
 
     private static void addToState(LongDecimalWithOverflowState state, BigInteger value)
     {
-        BlockBuilder blockBuilder = TYPE.createFixedSizeBlockBuilder(1);
-        TYPE.writeObject(blockBuilder, Int128.valueOf(value));
         if (TYPE.isShort()) {
-            DecimalSumAggregation.inputShortDecimal(state, blockBuilder.build(), 0);
+            DecimalSumAggregation.inputShortDecimal(state, Int128.valueOf(value).toLongExact());
         }
         else {
+            BlockBuilder blockBuilder = TYPE.createFixedSizeBlockBuilder(1);
+            TYPE.writeObject(blockBuilder, Int128.valueOf(value));
             DecimalSumAggregation.inputLongDecimal(state, blockBuilder.build(), 0);
         }
     }
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java
index 2b11a8f071a7..3542efdca6e5 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java
@@ -21,7 +21,6 @@
 import io.trino.operator.aggregation.groupby.AggregationTestInput;
 import io.trino.operator.aggregation.groupby.AggregationTestInputBuilder;
 import io.trino.operator.aggregation.groupby.AggregationTestOutput;
-import io.trino.operator.aggregation.histogram.Histogram;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.type.ArrayType;
@@ -77,28 +76,28 @@ public void testSimpleHistograms()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(VARCHAR),
                 ImmutableMap.of("a", 1L, "b", 1L, "c", 1L),
                 createStringsBlock("a", "b", "c"));
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(BIGINT),
                 ImmutableMap.of(100L, 1L, 200L, 1L, 300L, 1L),
                 createLongsBlock(100L, 200L, 300L));
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(DOUBLE),
                 ImmutableMap.of(0.1, 1L, 0.3, 1L, 0.2, 1L),
                 createDoublesBlock(0.1, 0.3, 0.2));
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(BOOLEAN),
                 ImmutableMap.of(true, 1L, false, 1L),
                 createBooleansBlock(true, false));
@@ -109,27 +108,27 @@ public void testSharedGroupBy()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(VARCHAR),
                 ImmutableMap.of("a", 1L, "b", 1L, "c", 1L),
                 createStringsBlock("a", "b", "c"));
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(BIGINT),
                 ImmutableMap.of(100L, 1L, 200L, 1L, 300L, 1L),
                 createLongsBlock(100L, 200L, 300L));
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME), fromTypes(DOUBLE),
+                QualifiedName.of("histogram"), fromTypes(DOUBLE),
                 ImmutableMap.of(0.1, 1L, 0.3, 1L, 0.2, 1L),
                 createDoublesBlock(0.1, 0.3, 0.2));
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(BOOLEAN),
                 ImmutableMap.of(true, 1L, false, 1L),
                 createBooleansBlock(true, false));
@@ -140,7 +139,7 @@ public void testDuplicateKeysValues()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(VARCHAR),
                 ImmutableMap.of("a", 2L, "b", 1L),
                 createStringsBlock("a", "b", "a"));
@@ -149,7 +148,7 @@ public void testDuplicateKeysValues()
         long timestampWithTimeZone2 = packDateTimeWithZone(new DateTime(2015, 1, 1, 0, 0, 0, 0, DATE_TIME_ZONE).getMillis(), TIME_ZONE_KEY);
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(TIMESTAMP_WITH_TIME_ZONE),
                 ImmutableMap.of(SqlTimestampWithTimeZone.newInstance(3, unpackMillisUtc(timestampWithTimeZone1), 0, unpackZoneKey(timestampWithTimeZone1)), 2L, SqlTimestampWithTimeZone.newInstance(3, unpackMillisUtc(timestampWithTimeZone2), 0, unpackZoneKey(timestampWithTimeZone2)), 1L),
                 createLongsBlock(timestampWithTimeZone1, timestampWithTimeZone1, timestampWithTimeZone2));
@@ -160,14 +159,14 @@ public void testWithNulls()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(BIGINT),
                 ImmutableMap.of(1L, 1L, 2L, 1L),
                 createLongsBlock(2L, null, 1L));
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(BIGINT),
                 null,
                 createLongsBlock((Long) null));
@@ -179,7 +178,7 @@ public void testArrayHistograms()
         ArrayType arrayType = new ArrayType(VARCHAR);
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(arrayType),
                 ImmutableMap.of(ImmutableList.of("a", "b", "c"), 1L, ImmutableList.of("d", "e", "f"), 1L, ImmutableList.of("c", "b", "a"), 1L),
                 createStringArraysBlock(ImmutableList.of(ImmutableList.of("a", "b", "c"), ImmutableList.of("d", "e", "f"), ImmutableList.of("c", "b", "a"))));
@@ -197,7 +196,7 @@ public void testMapHistograms()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(innerMapType),
                 ImmutableMap.of(ImmutableMap.of("a", "b"), 1L, ImmutableMap.of("c", "d"), 1L, ImmutableMap.of("e", "f"), 1L),
                 builder.build());
@@ -216,7 +215,7 @@ public void testRowHistograms()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(innerRowType),
                 ImmutableMap.of(ImmutableList.of(1L, 1.0), 1L, ImmutableList.of(2L, 2.0), 1L, ImmutableList.of(3L, 3.0), 1L),
                 builder.build());
@@ -227,7 +226,7 @@ public void testLargerHistograms()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(Histogram.NAME),
+                QualifiedName.of("histogram"),
                 fromTypes(VARCHAR),
                 ImmutableMap.of("a", 25L, "b", 10L, "c", 12L, "d", 1L, "e", 2L),
                 createStringsBlock("a", "b", "c", "d", "e", "e", "c", "a", "a", "a", "b", "a", "a", "a", "a", "b", "a", "a", "a", "a", "b", "a", "a", "a", "a", "b", "a", "a", "a", "a", "b", "a", "c", "c", "b", "a", "c", "c", "b", "a", "c", "c", "b", "a", "c", "c", "b", "a", "c", "c"));
@@ -392,6 +391,6 @@ private static void testSharedGroupByWithOverlappingValuesRunner(TestingAggregat
 
     private static TestingAggregationFunction getInternalDefaultVarCharAggregation()
     {
-        return FUNCTION_RESOLUTION.getAggregateFunction(QualifiedName.of(Histogram.NAME), fromTypes(VARCHAR));
+        return FUNCTION_RESOLUTION.getAggregateFunction(QualifiedName.of("histogram"), fromTypes(VARCHAR));
     }
 }
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapAggAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapAggAggregation.java
index e66a3dd478ba..11556b738b18 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapAggAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapAggAggregation.java
@@ -50,7 +50,7 @@ public void testDuplicateKeysValues()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, VARCHAR),
                 ImmutableMap.of(1.0, "a"),
                 createDoublesBlock(1.0, 1.0, 1.0),
@@ -58,7 +58,7 @@ public void testDuplicateKeysValues()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, INTEGER),
                 ImmutableMap.of(1.0, 99, 2.0, 99, 3.0, 99),
                 createDoublesBlock(1.0, 2.0, 3.0),
@@ -70,7 +70,7 @@ public void testSimpleMaps()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, VARCHAR),
                 ImmutableMap.of(1.0, "a", 2.0, "b", 3.0, "c"),
                 createDoublesBlock(1.0, 2.0, 3.0),
@@ -78,7 +78,7 @@ public void testSimpleMaps()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, INTEGER),
                 ImmutableMap.of(1.0, 3, 2.0, 2, 3.0, 1),
                 createDoublesBlock(1.0, 2.0, 3.0),
@@ -86,7 +86,7 @@ public void testSimpleMaps()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, BOOLEAN),
                 ImmutableMap.of(1.0, true, 2.0, false, 3.0, false),
                 createDoublesBlock(1.0, 2.0, 3.0),
@@ -98,7 +98,7 @@ public void testNull()
     {
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, DOUBLE),
                 ImmutableMap.of(1.0, 2.0),
                 createDoublesBlock(1.0, null, null),
@@ -106,7 +106,7 @@ public void testNull()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, DOUBLE),
                 null,
                 createDoublesBlock(null, null, null),
@@ -118,7 +118,7 @@ public void testNull()
         expected.put(3.0, null);
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, DOUBLE),
                 expected,
                 createDoublesBlock(1.0, 2.0, 3.0),
@@ -132,7 +132,7 @@ public void testDoubleArrayMap()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, arrayType),
                 ImmutableMap.of(1.0, ImmutableList.of("a", "b"),
                         2.0, ImmutableList.of("c", "d"),
@@ -153,7 +153,7 @@ public void testDoubleMapMap()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, innerMapType),
                 ImmutableMap.of(1.0, ImmutableMap.of("a", "b"),
                         2.0, ImmutableMap.of("c", "d"),
@@ -176,7 +176,7 @@ public void testDoubleRowMap()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(DOUBLE, innerRowType),
                 ImmutableMap.of(1.0, ImmutableList.of(1, 1.0),
                         2.0, ImmutableList.of(2, 2.0),
@@ -192,7 +192,7 @@ public void testArrayDoubleMap()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapAggregationFunction.NAME),
+                QualifiedName.of("map_agg"),
                 fromTypes(arrayType, DOUBLE),
                 ImmutableMap.of(
                         ImmutableList.of("a", "b"), 1.0,
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapUnionAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapUnionAggregation.java
index 12c53dd60430..1aedae3ca8cd 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapUnionAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMapUnionAggregation.java
@@ -45,7 +45,7 @@ public void testSimpleWithDuplicates()
         MapType mapType = mapType(DOUBLE, VARCHAR);
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapUnionAggregation.NAME),
+                QualifiedName.of("map_union"),
                 fromTypes(mapType),
                 ImmutableMap.of(23.0, "aaa", 33.0, "bbb", 43.0, "ccc", 53.0, "ddd", 13.0, "eee"),
                 arrayBlockOf(
@@ -56,7 +56,7 @@ public void testSimpleWithDuplicates()
         mapType = mapType(DOUBLE, BIGINT);
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapUnionAggregation.NAME), fromTypes(mapType),
+                QualifiedName.of("map_union"), fromTypes(mapType),
                 ImmutableMap.of(1.0, 99L, 2.0, 99L, 3.0, 99L, 4.0, 44L),
                 arrayBlockOf(
                         mapType,
@@ -66,7 +66,7 @@ public void testSimpleWithDuplicates()
         mapType = mapType(BOOLEAN, BIGINT);
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapUnionAggregation.NAME),
+                QualifiedName.of("map_union"),
                 fromTypes(mapType),
                 ImmutableMap.of(false, 12L, true, 13L),
                 arrayBlockOf(
@@ -84,7 +84,7 @@ public void testSimpleWithNulls()
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapUnionAggregation.NAME),
+                QualifiedName.of("map_union"),
                 fromTypes(mapType),
                 expected,
                 arrayBlockOf(
@@ -100,7 +100,7 @@ public void testStructural()
         MapType mapType = mapType(DOUBLE, new ArrayType(VARCHAR));
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapUnionAggregation.NAME),
+                QualifiedName.of("map_union"),
                 fromTypes(mapType),
                 ImmutableMap.of(
                         1.0, ImmutableList.of("a", "b"),
@@ -133,7 +133,7 @@ public void testStructural()
         mapType = mapType(DOUBLE, mapType(VARCHAR, VARCHAR));
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapUnionAggregation.NAME),
+                QualifiedName.of("map_union"),
                 fromTypes(mapType),
                 ImmutableMap.of(
                         1.0, ImmutableMap.of("a", "b"),
@@ -159,7 +159,7 @@ public void testStructural()
         mapType = mapType(new ArrayType(VARCHAR), DOUBLE);
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MapUnionAggregation.NAME),
+                QualifiedName.of("map_union"),
                 fromTypes(mapType),
                 ImmutableMap.of(
                         ImmutableList.of("a", "b"), 1.0,
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxby/TestMinMaxByAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMinMaxByAggregation.java
similarity index 99%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxby/TestMinMaxByAggregation.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/TestMinMaxByAggregation.java
index 4874d0e6063f..2d4e284e532d 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxby/TestMinMaxByAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMinMaxByAggregation.java
@@ -11,7 +11,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation.minmaxby;
+package io.trino.operator.aggregation;
 
 import com.google.common.collect.ImmutableList;
 import io.trino.FeaturesConfig;
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java
index 0a6cc2d2d623..65b0a13c5906 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java
@@ -22,7 +22,6 @@
 import io.trino.operator.aggregation.groupby.AggregationTestInput;
 import io.trino.operator.aggregation.groupby.AggregationTestInputBuilder;
 import io.trino.operator.aggregation.groupby.AggregationTestOutput;
-import io.trino.operator.aggregation.multimapagg.MultimapAggregationFunction;
 import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
@@ -181,7 +180,7 @@ public void testEmptyStateOutputIsNull()
 
     private static TestingAggregationFunction getAggregationFunction(Type keyType, Type valueType)
     {
-        return FUNCTION_RESOLUTION.getAggregateFunction(QualifiedName.of(MultimapAggregationFunction.NAME), fromTypes(keyType, valueType));
+        return FUNCTION_RESOLUTION.getAggregateFunction(QualifiedName.of("multimap_agg"), fromTypes(keyType, valueType));
     }
 
     /**
@@ -207,7 +206,7 @@ private static  void testMultimapAgg(Type keyType, List expectedKeys, T
 
         assertAggregation(
                 FUNCTION_RESOLUTION,
-                QualifiedName.of(MultimapAggregationFunction.NAME),
+                QualifiedName.of("multimap_agg"),
                 fromTypes(keyType, valueType),
                 map.isEmpty() ? null : map,
                 builder.build());
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java
index a61f56e45695..dd87d06c7821 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java
@@ -15,18 +15,27 @@
 
 import io.airlift.slice.Slice;
 import io.trino.block.BlockAssertions;
+import io.trino.metadata.TestingFunctionResolution;
 import io.trino.spi.TrinoException;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.block.VariableWidthBlockBuilder;
+import io.trino.sql.analyzer.TypeSignatureProvider;
+import io.trino.sql.tree.QualifiedName;
 import org.testcontainers.shaded.org.apache.commons.lang.StringUtils;
 import org.testng.annotations.Test;
 
+import java.util.List;
+
 import static io.airlift.slice.Slices.utf8Slice;
+import static io.trino.block.BlockAssertions.createBooleansBlock;
 import static io.trino.block.BlockAssertions.createStringsBlock;
+import static io.trino.operator.aggregation.AggregationTestUtils.assertAggregation;
 import static io.trino.spi.StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT;
 import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
+import static io.trino.spi.type.BooleanType.BOOLEAN;
 import static io.trino.spi.type.VarcharType.VARCHAR;
+import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
@@ -34,16 +43,18 @@
 
 public class TestListaggAggregationFunction
 {
+    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();
+
     @Test
     public void testInputEmptyState()
     {
-        SingleListaggAggregationState state = new SingleListaggAggregationState(VARCHAR);
+        SingleListaggAggregationState state = new SingleListaggAggregationState();
 
         String s = "value1";
         Block value = createStringsBlock(s);
         Slice separator = utf8Slice(",");
         Slice overflowFiller = utf8Slice("...");
-        ListaggAggregationFunction.input(VARCHAR,
+        ListaggAggregationFunction.input(
                 state,
                 value,
                 separator,
@@ -73,9 +84,9 @@ public void testInputOverflowOverflowFillerTooLong()
     {
         String overflowFillerTooLong = StringUtils.repeat(".", 65_537);
 
-        SingleListaggAggregationState state = new SingleListaggAggregationState(VARCHAR);
+        SingleListaggAggregationState state = new SingleListaggAggregationState();
 
-        assertThatThrownBy(() -> ListaggAggregationFunction.input(VARCHAR,
+        assertThatThrownBy(() -> ListaggAggregationFunction.input(
                 state,
                 createStringsBlock("value1"),
                 utf8Slice(","),
@@ -206,6 +217,32 @@ public void testOutputTruncatedStateWithIndicationCountComplexSeparator()
         assertEquals(getOutputStateOnlyValue(state, 21), "a###b###c###dd###e###...(7)");
     }
 
+    @Test
+    public void testExecute()
+    {
+        List parameterTypes = fromTypes(VARCHAR, VARCHAR, BOOLEAN, VARCHAR, BOOLEAN);
+        assertAggregation(
+                FUNCTION_RESOLUTION,
+                QualifiedName.of("listagg"),
+                parameterTypes,
+                null,
+                createStringsBlock(null, null, null),
+                createStringsBlock(",", ",", ","),
+                createBooleansBlock(false, false, false),
+                createStringsBlock("", "", ""),
+                createBooleansBlock(false, false, false));
+        assertAggregation(
+                FUNCTION_RESOLUTION,
+                QualifiedName.of("listagg"),
+                parameterTypes,
+                "a,c",
+                createStringsBlock("a", null, "c"),
+                createStringsBlock(",", ",", ","),
+                createBooleansBlock(false, false, false),
+                createStringsBlock("", "", ""),
+                createBooleansBlock(false, false, false));
+    }
+
     private static String getOutputStateOnlyValue(SingleListaggAggregationState state, int maxOutputLengthInBytes)
     {
         BlockBuilder out = new VariableWidthBlockBuilder(null, 32, 256);
@@ -215,7 +252,7 @@ private static String getOutputStateOnlyValue(SingleListaggAggregationState stat
 
     private static SingleListaggAggregationState createListaggAggregationState(String separator, boolean overflowError, String overflowFiller, boolean showOverflowEntryCount, String... values)
     {
-        SingleListaggAggregationState state = new SingleListaggAggregationState(VARCHAR);
+        SingleListaggAggregationState state = new SingleListaggAggregationState();
         state.setSeparator(utf8Slice(separator));
         state.setOverflowError(overflowError);
         state.setOverflowFiller(utf8Slice(overflowFiller));
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestMinMaxByNAggregation.java
similarity index 99%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestMinMaxByNAggregation.java
index 77fcb905086c..561769c41642 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestMinMaxByNAggregation.java
@@ -11,7 +11,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation.minmaxby;
+package io.trino.operator.aggregation.minmaxbyn;
 
 import com.google.common.collect.ImmutableList;
 import io.trino.metadata.TestingFunctionResolution;
@@ -590,7 +590,7 @@ public void testOutOfBound()
             groupedAggregation(function, new Page(createStringsBlock("z"), createLongsBlock(0), createLongsBlock(10001)));
         }
         catch (TrinoException e) {
-            assertEquals(e.getMessage(), "third argument of max_by/min_by must be less than or equal to 10000; found 10001");
+            assertEquals(e.getMessage(), "third argument of max_by must be less than or equal to 10000; found 10001");
         }
     }
 }
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedKeyValueHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java
similarity index 88%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedKeyValueHeap.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java
index a1318a76bd6d..bfef6e07ca53 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedKeyValueHeap.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java
@@ -11,7 +11,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation;
+package io.trino.operator.aggregation.minmaxbyn;
 
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
@@ -31,7 +31,6 @@
 import static io.trino.spi.function.InvocationConvention.simpleConvention;
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.VarcharType.VARCHAR;
-import static io.trino.util.MinMaxCompare.getMinMaxCompare;
 import static org.testng.Assert.assertEquals;
 
 public class TestTypedKeyValueHeap
@@ -40,18 +39,20 @@ public class TestTypedKeyValueHeap
     private static final int OUTPUT_SIZE = 1_000;
 
     private static final TypeOperators TYPE_OPERATOR_FACTORY = new TypeOperators();
-    private static final MethodHandle MAX_ELEMENTS_COMPARATOR = getMinMaxCompare(TYPE_OPERATOR_FACTORY, BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), false);
-    private static final MethodHandle MIN_ELEMENTS_COMPARATOR = getMinMaxCompare(TYPE_OPERATOR_FACTORY, BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), true);
+    private static final MethodHandle MAX_ELEMENTS_COMPARATOR = TYPE_OPERATOR_FACTORY.getComparisonUnorderedFirstOperator(BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION));
+    private static final MethodHandle MIN_ELEMENTS_COMPARATOR = TYPE_OPERATOR_FACTORY.getComparisonUnorderedLastOperator(BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION));
 
     @Test
     public void testAscending()
     {
         test(IntStream.range(0, INPUT_SIZE),
                 IntStream.range(0, INPUT_SIZE).mapToObj(key -> Integer.toString(key * 2)),
+                false,
                 MAX_ELEMENTS_COMPARATOR,
                 IntStream.range(INPUT_SIZE - OUTPUT_SIZE, INPUT_SIZE).mapToObj(key -> Integer.toString(key * 2)).iterator());
         test(IntStream.range(0, INPUT_SIZE),
                 IntStream.range(0, INPUT_SIZE).mapToObj(key -> Integer.toString(key * 2)),
+                true,
                 MIN_ELEMENTS_COMPARATOR,
                 IntStream.range(0, OUTPUT_SIZE).map(x -> OUTPUT_SIZE - 1 - x).mapToObj(key -> Integer.toString(key * 2)).iterator());
     }
@@ -61,10 +62,12 @@ public void testDescending()
     {
         test(IntStream.range(0, INPUT_SIZE).map(x -> INPUT_SIZE - 1 - x),
                 IntStream.range(0, INPUT_SIZE).map(x -> INPUT_SIZE - 1 - x).mapToObj(key -> Integer.toString(key * 2)),
+                false,
                 MAX_ELEMENTS_COMPARATOR,
                 IntStream.range(INPUT_SIZE - OUTPUT_SIZE, INPUT_SIZE).mapToObj(key -> Integer.toString(key * 2)).iterator());
         test(IntStream.range(0, INPUT_SIZE).map(x -> INPUT_SIZE - 1 - x),
                 IntStream.range(0, INPUT_SIZE).map(x -> INPUT_SIZE - 1 - x).mapToObj(key -> Integer.toString(key * 2)),
+                true,
                 MIN_ELEMENTS_COMPARATOR,
                 IntStream.range(0, OUTPUT_SIZE).map(x -> OUTPUT_SIZE - 1 - x).mapToObj(key -> Integer.toString(key * 2)).iterator());
     }
@@ -76,22 +79,24 @@ public void testShuffled()
         Collections.shuffle(list);
         test(list.stream().mapToInt(Integer::intValue),
                 list.stream().mapToInt(Integer::intValue).mapToObj(key -> Integer.toString(key * 2)),
+                false,
                 MAX_ELEMENTS_COMPARATOR,
                 IntStream.range(INPUT_SIZE - OUTPUT_SIZE, INPUT_SIZE).mapToObj(key -> Integer.toString(key * 2)).iterator());
         test(list.stream().mapToInt(Integer::intValue),
                 list.stream().mapToInt(Integer::intValue).mapToObj(key -> Integer.toString(key * 2)),
+                true,
                 MIN_ELEMENTS_COMPARATOR,
                 IntStream.range(0, OUTPUT_SIZE).map(x -> OUTPUT_SIZE - 1 - x).mapToObj(key -> Integer.toString(key * 2)).iterator());
     }
 
-    private static void test(IntStream keyInputStream, Stream valueInputStream, MethodHandle keyComparisonMethod, Iterator outputIterator)
+    private static void test(IntStream keyInputStream, Stream valueInputStream, boolean min, MethodHandle comparisonMethod, Iterator outputIterator)
     {
         BlockBuilder keysBlockBuilder = BIGINT.createBlockBuilder(null, INPUT_SIZE);
         BlockBuilder valuesBlockBuilder = VARCHAR.createBlockBuilder(null, INPUT_SIZE);
         keyInputStream.forEach(x -> BIGINT.writeLong(keysBlockBuilder, x));
         valueInputStream.forEach(x -> VARCHAR.writeString(valuesBlockBuilder, x));
 
-        TypedKeyValueHeap heap = new TypedKeyValueHeap(keyComparisonMethod, BIGINT, VARCHAR, OUTPUT_SIZE);
+        TypedKeyValueHeap heap = new TypedKeyValueHeap(min, comparisonMethod, BIGINT, VARCHAR, OUTPUT_SIZE);
         heap.addAll(keysBlockBuilder, valuesBlockBuilder);
 
         BlockBuilder resultBlockBuilder = VARCHAR.createBlockBuilder(null, OUTPUT_SIZE);
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestArrayMaxNAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestArrayMaxNAggregation.java
similarity index 97%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/TestArrayMaxNAggregation.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestArrayMaxNAggregation.java
index 7a9e7c860da8..5fea0777d28b 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestArrayMaxNAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestArrayMaxNAggregation.java
@@ -11,10 +11,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation;
+package io.trino.operator.aggregation.minmaxn;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
+import io.trino.operator.aggregation.AbstractTestAggregationFunction;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.type.ArrayType;
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleMinNAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestDoubleMinNAggregation.java
similarity index 96%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleMinNAggregation.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestDoubleMinNAggregation.java
index d0ea9f6d4bb6..590af88cb809 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleMinNAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestDoubleMinNAggregation.java
@@ -11,9 +11,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation;
+package io.trino.operator.aggregation.minmaxn;
 
 import com.google.common.collect.ImmutableList;
+import io.trino.operator.aggregation.AbstractTestAggregationFunction;
 import io.trino.spi.block.Block;
 import io.trino.spi.type.Type;
 import org.testng.annotations.Test;
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestLongMaxNAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestLongMaxNAggregation.java
similarity index 96%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/TestLongMaxNAggregation.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestLongMaxNAggregation.java
index 6cd03c23c279..31fe823d3693 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestLongMaxNAggregation.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestLongMaxNAggregation.java
@@ -11,9 +11,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation;
+package io.trino.operator.aggregation.minmaxn;
 
 import com.google.common.collect.ImmutableList;
+import io.trino.operator.aggregation.AbstractTestAggregationFunction;
 import io.trino.spi.block.Block;
 import io.trino.spi.type.Type;
 import org.testng.annotations.Test;
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java
similarity index 79%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHeap.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java
index 9ea1fc6350f1..6e74e4a70108 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHeap.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java
@@ -11,10 +11,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation;
+package io.trino.operator.aggregation.minmaxn;
 
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
 import org.testng.annotations.Test;
 
@@ -29,7 +30,6 @@
 import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
 import static io.trino.spi.function.InvocationConvention.simpleConvention;
 import static io.trino.spi.type.BigintType.BIGINT;
-import static io.trino.util.MinMaxCompare.getMinMaxCompare;
 import static org.testng.Assert.assertEquals;
 
 public class TestTypedHeap
@@ -38,17 +38,21 @@ public class TestTypedHeap
     private static final int OUTPUT_SIZE = 1_000;
 
     private static final TypeOperators TYPE_OPERATOR_FACTORY = new TypeOperators();
-    private static final MethodHandle MAX_ELEMENTS_COMPARATOR = getMinMaxCompare(TYPE_OPERATOR_FACTORY, BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), false);
-    private static final MethodHandle MIN_ELEMENTS_COMPARATOR = getMinMaxCompare(TYPE_OPERATOR_FACTORY, BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), true);
+    private static final MethodHandle MAX_ELEMENTS_COMPARATOR = TYPE_OPERATOR_FACTORY.getComparisonUnorderedFirstOperator(BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION));
+    private static final MethodHandle MIN_ELEMENTS_COMPARATOR = TYPE_OPERATOR_FACTORY.getComparisonUnorderedLastOperator(BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION));
 
     @Test
     public void testAscending()
     {
         test(IntStream.range(0, INPUT_SIZE),
+                false,
                 MAX_ELEMENTS_COMPARATOR,
+                BIGINT,
                 IntStream.range(INPUT_SIZE - OUTPUT_SIZE, INPUT_SIZE).iterator());
         test(IntStream.range(0, INPUT_SIZE),
+                true,
                 MIN_ELEMENTS_COMPARATOR,
+                BIGINT,
                 IntStream.range(0, OUTPUT_SIZE).map(x -> OUTPUT_SIZE - 1 - x).iterator());
     }
 
@@ -56,10 +60,14 @@ public void testAscending()
     public void testDescending()
     {
         test(IntStream.range(0, INPUT_SIZE).map(x -> INPUT_SIZE - 1 - x),
+                false,
                 MAX_ELEMENTS_COMPARATOR,
+                BIGINT,
                 IntStream.range(INPUT_SIZE - OUTPUT_SIZE, INPUT_SIZE).iterator());
         test(IntStream.range(0, INPUT_SIZE).map(x -> INPUT_SIZE - 1 - x),
+                true,
                 MIN_ELEMENTS_COMPARATOR,
+                BIGINT,
                 IntStream.range(0, OUTPUT_SIZE).map(x -> OUTPUT_SIZE - 1 - x).iterator());
     }
 
@@ -69,19 +77,23 @@ public void testShuffled()
         List list = IntStream.range(0, INPUT_SIZE).collect(ArrayList::new, ArrayList::add, ArrayList::addAll);
         Collections.shuffle(list);
         test(list.stream().mapToInt(Integer::intValue),
+                false,
                 MAX_ELEMENTS_COMPARATOR,
+                BIGINT,
                 IntStream.range(INPUT_SIZE - OUTPUT_SIZE, INPUT_SIZE).iterator());
         test(list.stream().mapToInt(Integer::intValue),
+                true,
                 MIN_ELEMENTS_COMPARATOR,
+                BIGINT,
                 IntStream.range(0, OUTPUT_SIZE).map(x -> OUTPUT_SIZE - 1 - x).iterator());
     }
 
-    private static void test(IntStream inputStream, MethodHandle greaterThanMethod, PrimitiveIterator.OfInt outputIterator)
+    private static void test(IntStream inputStream, boolean min, MethodHandle comparisonMethod, Type elementType, PrimitiveIterator.OfInt outputIterator)
     {
         BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, INPUT_SIZE);
-        inputStream.forEach(x -> BIGINT.writeLong(blockBuilder, x));
+        inputStream.forEach(value -> BIGINT.writeLong(blockBuilder, value));
 
-        TypedHeap heap = new TypedHeap(greaterThanMethod, BIGINT, OUTPUT_SIZE);
+        TypedHeap heap = new TypedHeap(min, comparisonMethod, elementType, OUTPUT_SIZE);
         heap.addAll(blockBuilder);
 
         BlockBuilder resultBlockBuilder = BIGINT.createBlockBuilder(null, OUTPUT_SIZE);
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestStateCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestStateCompiler.java
similarity index 97%
rename from core/trino-main/src/test/java/io/trino/operator/aggregation/TestStateCompiler.java
rename to core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestStateCompiler.java
index f685ae9f535a..28e633ae5376 100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestStateCompiler.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestStateCompiler.java
@@ -11,7 +11,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation;
+package io.trino.operator.aggregation.state;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -24,10 +24,6 @@
 import io.trino.array.LongBigArray;
 import io.trino.array.ReferenceCountMap;
 import io.trino.array.SliceBigArray;
-import io.trino.operator.aggregation.state.LongState;
-import io.trino.operator.aggregation.state.NullableLongState;
-import io.trino.operator.aggregation.state.StateCompiler;
-import io.trino.operator.aggregation.state.VarianceState;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.function.AccumulatorState;
@@ -58,6 +54,8 @@
 import static io.trino.util.StructuralTestUtil.mapBlockOf;
 import static io.trino.util.StructuralTestUtil.mapType;
 import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertTrue;
 
 public class TestStateCompiler
 {
@@ -79,12 +77,12 @@ public void testPrimitiveNullableLongSerialization()
 
         Block block = builder.build();
 
-        assertEquals(block.isNull(0), false);
+        assertFalse(block.isNull(0));
         assertEquals(BIGINT.getLong(block, 0), state.getValue());
         serializer.deserialize(block, 0, deserializedState);
         assertEquals(deserializedState.getValue(), state.getValue());
 
-        assertEquals(block.isNull(1), true);
+        assertTrue(block.isNull(1));
     }
 
     @Test
@@ -239,7 +237,7 @@ public void testComplexSerialization()
         assertEquals(deserializedState.getAnotherBlock().getSlice(1, 0, 9), singleState.getAnotherBlock().getSlice(1, 0, 9));
     }
 
-    private long getComplexStateRetainedSize(TestComplexState state)
+    private static long getComplexStateRetainedSize(TestComplexState state)
     {
         long retainedSize = ClassLayout.parseClass(state.getClass()).instanceSize();
         // reflection is necessary because TestComplexState implementation is generated
diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestLambdaExpression.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestLambdaExpression.java
index e8306a4aabaf..3853e4125132 100644
--- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestLambdaExpression.java
+++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestLambdaExpression.java
@@ -172,8 +172,8 @@ public void testTypeCombinations()
     @Test
     public void testFunctionParameter()
     {
-        assertInvalidFunction("count(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters () for function count. Expected: count(), count(T) T");
-        assertInvalidFunction("max(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters () for function max. Expected: max(E) E:orderable, max(E, bigint) E:orderable");
+        assertInvalidFunction("count(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters () for function count. Expected: count(), count(t) T");
+        assertInvalidFunction("max(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters () for function max. Expected: max(t) T:orderable, max(e, bigint) E:orderable");
         assertInvalidFunction("sqrt(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters () for function sqrt. Expected: sqrt(double)");
         assertInvalidFunction("sqrt(x -> x, 123, x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters (, integer, ) for function sqrt. Expected: sqrt(double)");
         assertInvalidFunction("pow(x -> x, 123)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters (, integer) for function pow. Expected: pow(double, double)");
diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/AccumulatorStateMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/function/AccumulatorStateMetadata.java
index 62bc32786e81..5675a79877d0 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/function/AccumulatorStateMetadata.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/function/AccumulatorStateMetadata.java
@@ -25,4 +25,8 @@
     Class stateSerializerClass() default AccumulatorStateSerializer.class;
 
     Class stateFactoryClass() default AccumulatorStateFactory.class;
+
+    String[] typeParameters() default {};
+
+    String serializedType() default "";
 }
diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/AggregationState.java b/core/trino-spi/src/main/java/io/trino/spi/function/AggregationState.java
index d71aca0fb2e7..f08caedfab26 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/function/AggregationState.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/function/AggregationState.java
@@ -24,5 +24,5 @@
 @Target({METHOD, PARAMETER})
 public @interface AggregationState
 {
-    String value() default "";
+    String[] value() default {};
 }
diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/InOut.java b/core/trino-spi/src/main/java/io/trino/spi/function/InOut.java
new file mode 100644
index 000000000000..3ea8a5ec1599
--- /dev/null
+++ b/core/trino-spi/src/main/java/io/trino/spi/function/InOut.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.spi.function;
+
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.type.Type;
+
+@AccumulatorStateMetadata(typeParameters = "T", serializedType = "T")
+public interface InOut
+        extends AccumulatorState
+{
+    Type getType();
+
+    boolean isNull();
+
+    void get(BlockBuilder blockBuilder);
+
+    void set(Block block, int position);
+
+    void set(InOut otherState);
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByNAggregationFunction.java b/core/trino-spi/src/main/java/io/trino/spi/function/InternalDataAccessor.java
similarity index 50%
rename from core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByNAggregationFunction.java
rename to core/trino-spi/src/main/java/io/trino/spi/function/InternalDataAccessor.java
index c43cde39cbac..f041354a6e6a 100644
--- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByNAggregationFunction.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/function/InternalDataAccessor.java
@@ -11,17 +11,31 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package io.trino.operator.aggregation.minmaxby;
+package io.trino.spi.function;
 
-import io.trino.type.BlockTypeOperators;
-
-public class MinByNAggregationFunction
-        extends AbstractMinMaxByNAggregationFunction
+/**
+ * This interface is an internal detail of the SPI and should never be used.  It will be changed
+ * without notification.
+ */
+public interface InternalDataAccessor
 {
-    private static final String NAME = "min_by";
+    default boolean getBooleanValue()
+    {
+        throw new UnsupportedOperationException();
+    }
+
+    default double getDoubleValue()
+    {
+        throw new UnsupportedOperationException();
+    }
+
+    default long getLongValue()
+    {
+        throw new UnsupportedOperationException();
+    }
 
-    public MinByNAggregationFunction(BlockTypeOperators blockTypeOperators)
+    default Object getObjectValue()
     {
-        super(NAME, true, "Returns the values of the first argument associated with the minimum values of the second argument");
+        throw new UnsupportedOperationException();
     }
 }
diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java
index d0fd1faaa2a9..a11e03f158a0 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java
@@ -123,6 +123,10 @@ public enum InvocationArgumentConvention
          * sql value may be null.
          */
         BLOCK_POSITION(true, 2),
+        /**
+         * Argument is passed in an InOut. The sql value may be null.
+         */
+        IN_OUT(true, 1),
         /**
          * Argument is a lambda function.
          */
diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java
index b439db8af291..ff87e9b1358d 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java
@@ -31,6 +31,7 @@
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION;
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
 import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
@@ -140,8 +141,8 @@ private boolean canAdaptParameter(
             return true;
         }
 
-        // no conversions from function or block and position are supported
-        if (actualArgumentConvention == BLOCK_POSITION || actualArgumentConvention == FUNCTION) {
+        // no conversions to block and position, function, or in-out are supported
+        if (actualArgumentConvention == BLOCK_POSITION || actualArgumentConvention == FUNCTION || actualArgumentConvention == IN_OUT) {
             return false;
         }
 
@@ -150,8 +151,8 @@ private boolean canAdaptParameter(
             return true;
         }
 
-        // nulls are passed in blocks, so adapter will handle null or throw exception at runtime
-        if (expectedArgumentConvention == BLOCK_POSITION) {
+        // nulls are passed in blocks or in-out values, so adapter will handle null or throw exception at runtime
+        if (expectedArgumentConvention == BLOCK_POSITION || expectedArgumentConvention == IN_OUT) {
             return true;
         }
 
@@ -289,10 +290,13 @@ private MethodHandle adaptParameter(
             return methodHandle;
         }
         if (actualArgumentConvention == BLOCK_POSITION) {
-            throw new IllegalArgumentException("Block and position argument can not be adapted");
+            throw new IllegalArgumentException("Block and position argument cannot be adapted");
+        }
+        if (actualArgumentConvention == IN_OUT) {
+            throw new IllegalArgumentException("In-out argument cannot be adapted");
         }
         if (actualArgumentConvention == FUNCTION) {
-            throw new IllegalArgumentException("Function argument can not be adapted");
+            throw new IllegalArgumentException("Function argument cannot be adapted");
         }
 
         // caller will never pass null
@@ -426,7 +430,7 @@ private MethodHandle adaptParameter(
             throw new IllegalArgumentException("Unsupported actual argument convention: " + actualArgumentConvention);
         }
 
-        // caller will pass boolean true in the next argument for SQL null
+        // caller passes block and position which may contain a null
         if (expectedArgumentConvention == BLOCK_POSITION) {
             MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex));
 
@@ -489,6 +493,70 @@ private MethodHandle adaptParameter(
             throw new IllegalArgumentException("Unsupported actual argument convention: " + actualArgumentConvention);
         }
 
+        // caller passes in-out which may contain a null
+        if (expectedArgumentConvention == IN_OUT) {
+            MethodHandle getInOutValue = getInOutValue(argumentType, methodHandle.type().parameterType(parameterIndex));
+
+            if (actualArgumentConvention == NEVER_NULL) {
+                if (nullAdaptationPolicy == UNDEFINED_VALUE_FOR_NULL) {
+                    // Current, null is not checked, so whatever value returned is passed through
+                    methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue);
+                    return methodHandle;
+                }
+
+                if (nullAdaptationPolicy == RETURN_NULL_ON_NULL && returnConvention != FAIL_ON_NULL) {
+                    // if caller sets null flag, return null, otherwise invoke target
+                    methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue);
+                    return guardWithTest(
+                            isInOutNull(methodHandle.type(), parameterIndex),
+                            returnNull(methodHandle.type()),
+                            methodHandle);
+                }
+
+                if (nullAdaptationPolicy == THROW_ON_NULL || nullAdaptationPolicy == UNSUPPORTED || nullAdaptationPolicy == RETURN_NULL_ON_NULL) {
+                    MethodHandle adapter = guardWithTest(
+                            isInOutNull(getInOutValue.type(), 0),
+                            throwTrinoNullArgumentException(getInOutValue.type()),
+                            getInOutValue);
+
+                    return collectArguments(methodHandle, parameterIndex, adapter);
+                }
+            }
+
+            if (actualArgumentConvention == BOXED_NULLABLE) {
+                getInOutValue = explicitCastArguments(getInOutValue, getInOutValue.type().changeReturnType(wrap(getInOutValue.type().returnType())));
+                getInOutValue = guardWithTest(
+                        isInOutNull(getInOutValue.type(), 0),
+                        returnNull(getInOutValue.type()),
+                        getInOutValue);
+                methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue);
+                return methodHandle;
+            }
+
+            if (actualArgumentConvention == NULL_FLAG) {
+                // long, boolean => long, InOut
+                MethodHandle isNull = isInOutNull(getInOutValue.type(), 0);
+                methodHandle = collectArguments(methodHandle, parameterIndex + 1, isNull);
+
+                // long, InOut => InOut, InOut
+                getInOutValue = guardWithTest(
+                        isInOutNull(getInOutValue.type(), 0),
+                        returnNull(getInOutValue.type()),
+                        getInOutValue);
+                methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue);
+
+                // InOut, InOut => InOut
+                int[] reorder = IntStream.range(0, methodHandle.type().parameterCount())
+                        .map(i -> i <= parameterIndex ? i : i - 1)
+                        .toArray();
+                MethodType newType = methodHandle.type().dropParameterTypes(parameterIndex + 1, parameterIndex + 2);
+                methodHandle = permuteArguments(methodHandle, newType, reorder);
+                return methodHandle;
+            }
+
+            throw new IllegalArgumentException("Unsupported actual argument convention: " + actualArgumentConvention);
+        }
+
         throw new IllegalArgumentException("Unsupported expected argument convention: " + expectedArgumentConvention);
     }
 
@@ -523,6 +591,33 @@ else if (methodArgumentType == Slice.class) {
         }
     }
 
+    private static MethodHandle getInOutValue(Type argumentType, Class expectedType)
+    {
+        Class methodArgumentType = argumentType.getJavaType();
+        String getterName;
+        if (methodArgumentType == boolean.class) {
+            getterName = "getBooleanValue";
+        }
+        else if (methodArgumentType == long.class) {
+            getterName = "getLongValue";
+        }
+        else if (methodArgumentType == double.class) {
+            getterName = "getDoubleValue";
+        }
+        else {
+            getterName = "getObjectValue";
+            methodArgumentType = Object.class;
+        }
+
+        try {
+            MethodHandle getValue = lookup().findVirtual(InternalDataAccessor.class, getterName, methodType(methodArgumentType));
+            return explicitCastArguments(getValue, methodType(expectedType, InOut.class));
+        }
+        catch (ReflectiveOperationException e) {
+            throw new AssertionError(e);
+        }
+    }
+
     private static MethodHandle boxedToNullFlagFilter(Class argumentType)
     {
         // Start with identity
@@ -571,6 +666,19 @@ private static MethodHandle isBlockPositionNull(MethodType methodType, int index
         return isNull;
     }
 
+    private static MethodHandle isInOutNull(MethodType methodType, int index)
+    {
+        MethodHandle isNull;
+        try {
+            isNull = lookup().findVirtual(InOut.class, "isNull", methodType(boolean.class));
+        }
+        catch (ReflectiveOperationException e) {
+            throw new AssertionError(e);
+        }
+        isNull = permuteArguments(isNull, methodType.changeReturnType(boolean.class), index);
+        return isNull;
+    }
+
     private static MethodHandle lookupIsNullMethod()
     {
         MethodHandle isNull;
diff --git a/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java b/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java
index 771cdd486020..d940dd4a778e 100644
--- a/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java
+++ b/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java
@@ -63,6 +63,8 @@ public class TestSpiBackwardCompatibility
                     "Method: public io.trino.spi.ptf.ReturnTypeSpecification io.trino.spi.ptf.ConnectorTableFunction.getReturnTypeSpecification()",
                     "Method: public java.lang.String io.trino.spi.ptf.ConnectorTableFunction.getName()",
                     "Method: public java.lang.String io.trino.spi.ptf.ConnectorTableFunction.getSchema()"))
+            .put("383", ImmutableSet.of(
+                    "Method: public abstract java.lang.String io.trino.spi.function.AggregationState.value()"))
             .buildOrThrow();
 
     @Test
diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java
index 26f7990bad2b..188d4ba373c0 100644
--- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java
+++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java
@@ -39,6 +39,7 @@
 import java.util.List;
 import java.util.stream.IntStream;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -46,6 +47,7 @@
 import static io.trino.spi.block.TestingSession.SESSION;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
 import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
 import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
@@ -63,6 +65,7 @@
 import static java.lang.invoke.MethodHandles.lookup;
 import static java.lang.invoke.MethodType.methodType;
 import static java.util.Collections.nCopies;
+import static java.util.Objects.requireNonNull;
 import static org.assertj.core.api.Fail.fail;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
@@ -181,7 +184,7 @@ private static void verifyAllAdaptations(
             throws Throwable
     {
         List> allArgumentConventions = allCombinations(
-                ImmutableList.of(NEVER_NULL, BOXED_NULLABLE, NULL_FLAG, BLOCK_POSITION),
+                ImmutableList.of(NEVER_NULL, BOXED_NULLABLE, NULL_FLAG, BLOCK_POSITION, IN_OUT),
                 argumentTypes.size());
         for (List argumentConventions : allArgumentConventions) {
             for (InvocationReturnConvention returnConvention : InvocationReturnConvention.values()) {
@@ -273,6 +276,10 @@ private static boolean hasNullableToNoNullableAdaptation(InvocationConvention ac
                 // this conversion is not allowed
                 return true;
             }
+            if (actualArgumentConvention == IN_OUT) {
+                // this conversion is not allowed
+                return true;
+            }
         }
         return false;
     }
@@ -290,7 +297,9 @@ private static boolean canCallConventionWithNullArguments(InvocationConvention c
     private static boolean hasNullBlockAndPositionToNeverNullArgument(InvocationConvention actualConvention, InvocationConvention expectedConvention, BitSet nullArguments)
     {
         for (int i = 0; i < actualConvention.getArgumentConventions().size(); i++) {
-            if (nullArguments.get(i) && actualConvention.getArgumentConvention(i) == NEVER_NULL && expectedConvention.getArgumentConvention(i) == BLOCK_POSITION) {
+            InvocationArgumentConvention argumentConvention = actualConvention.getArgumentConvention(i);
+            InvocationArgumentConvention expectedArgumentConvention = expectedConvention.getArgumentConvention(i);
+            if (nullArguments.get(i) && argumentConvention == NEVER_NULL && (expectedArgumentConvention == BLOCK_POSITION || expectedArgumentConvention == IN_OUT)) {
                 return true;
             }
         }
@@ -322,6 +331,9 @@ private static List> toCallArgumentTypes(InvocationConvention callingCo
                     expectedArguments.add(Block.class);
                     expectedArguments.add(int.class);
                     break;
+                case IN_OUT:
+                    expectedArguments.add(InOut.class);
+                    break;
                 default:
                     throw new IllegalArgumentException("Unsupported argument convention: " + argumentConvention);
             }
@@ -367,6 +379,9 @@ private static List toCallArgumentValues(InvocationConvention callingCon
                     callArguments.add(blockBuilder.build());
                     callArguments.add(1);
                     break;
+                case IN_OUT:
+                    callArguments.add(new TestingInOut(argumentType, testValue));
+                    break;
                 default:
                     throw new IllegalArgumentException("Unsupported argument convention: " + argumentConvention);
             }
@@ -699,4 +714,164 @@ private static void assertBlockEquals(Type type, Block actual, Block expected)
             }
         }
     }
+
+    private static class TestingInOut
+            implements InOut, InternalDataAccessor
+    {
+        private final Type type;
+        private Object value;
+
+        public TestingInOut(Type type, Object value)
+        {
+            this.type = requireNonNull(type, "type is null");
+            this.value = value;
+
+            if (value != null) {
+                Class javaType = type.getJavaType();
+                if (javaType.equals(boolean.class)) {
+                    checkArgument(value instanceof Boolean, "Value must be a Boolean for type %s", type);
+                }
+                else if (javaType.equals(long.class)) {
+                    checkArgument(value instanceof Long, "Value must be a Long for type %s", type);
+                }
+                else if (javaType.equals(double.class)) {
+                    checkArgument(value instanceof Double, "Value must be a Double for type %s", type);
+                }
+            }
+        }
+
+        @Override
+        public AccumulatorState copy()
+        {
+            return new TestingInOut(type, value);
+        }
+
+        @Override
+        public long getEstimatedSize()
+        {
+            return 0;
+        }
+
+        @Override
+        public Type getType()
+        {
+            return type;
+        }
+
+        @Override
+        public final boolean isNull()
+        {
+            return value == null;
+        }
+
+        @Override
+        public final void get(BlockBuilder blockBuilder)
+        {
+            Class javaType = type.getJavaType();
+
+            Object value = this.value;
+            if (value == null) {
+                blockBuilder.appendNull();
+            }
+            else if (javaType.equals(boolean.class)) {
+                type.writeBoolean(blockBuilder, (Boolean) value);
+            }
+            else if (javaType.equals(long.class)) {
+                type.writeLong(blockBuilder, (Long) value);
+            }
+            else if (javaType.equals(double.class)) {
+                type.writeDouble(blockBuilder, (Double) value);
+            }
+            else if (javaType.equals(Slice.class)) {
+                type.writeSlice(blockBuilder, (Slice) value);
+            }
+            else {
+                type.writeObject(blockBuilder, value);
+            }
+        }
+
+        @Override
+        public final void set(Block block, int position)
+        {
+            Class javaType = type.getJavaType();
+
+            Object value;
+            if (block.isNull(position)) {
+                value = null;
+            }
+            else if (javaType.equals(boolean.class)) {
+                value = type.getBoolean(block, position);
+            }
+            else if (javaType.equals(long.class)) {
+                value = type.getLong(block, position);
+            }
+            else if (javaType.equals(double.class)) {
+                value = type.getDouble(block, position);
+            }
+            else if (javaType.equals(Slice.class)) {
+                value = type.getSlice(block, position);
+            }
+            else {
+                value = type.getObject(block, position);
+            }
+            this.value = value;
+        }
+
+        @Override
+        public final void set(InOut otherState)
+        {
+            checkArgument(type.equals(otherState.getType()), "Expected other state to be type %s, but is type %s", type, otherState.getType());
+
+            Class javaType = type.getJavaType();
+            Object value;
+            if (otherState.isNull()) {
+                value = null;
+            }
+            else if (javaType.equals(boolean.class)) {
+                value = ((InternalDataAccessor) otherState).getBooleanValue();
+            }
+            else if (javaType.equals(long.class)) {
+                value = ((InternalDataAccessor) otherState).getLongValue();
+            }
+            else if (javaType.equals(double.class)) {
+                value = ((InternalDataAccessor) otherState).getDoubleValue();
+            }
+            else {
+                value = ((InternalDataAccessor) otherState).getObjectValue();
+            }
+            this.value = value;
+        }
+
+        @Override
+        public final boolean getBooleanValue()
+        {
+            checkArgument(type.getJavaType().equals(boolean.class), "Type %s does not have a boolean stack type", type);
+            Object value = this.value;
+            return value != null && (Boolean) value;
+        }
+
+        @Override
+        public final double getDoubleValue()
+        {
+            checkArgument(type.getJavaType().equals(double.class), "Type %s does not have a double stack type", type);
+            Object value = this.value;
+            return value == null ? 0.0 : (Double) value;
+        }
+
+        @Override
+        public final long getLongValue()
+        {
+            checkArgument(type.getJavaType().equals(long.class), "Type %s does not have a long stack type", type);
+            Object value = this.value;
+            return value == null ? 0L : (Long) value;
+        }
+
+        @Override
+        public final Object getObjectValue()
+        {
+            checkArgument(!type.getJavaType().isPrimitive(), "Type %s does not have an Object stack type", type);
+            Object value = this.value;
+            return value;
+        }
+    }
 }
diff --git a/testing/trino-product-tests/src/main/resources/sql-tests/testcases/map_functions/checkMapFunctionsRegistered.result b/testing/trino-product-tests/src/main/resources/sql-tests/testcases/map_functions/checkMapFunctionsRegistered.result
index 53ad4bab8963..79aa49d325a3 100644
--- a/testing/trino-product-tests/src/main/resources/sql-tests/testcases/map_functions/checkMapFunctionsRegistered.result
+++ b/testing/trino-product-tests/src/main/resources/sql-tests/testcases/map_functions/checkMapFunctionsRegistered.result
@@ -1,6 +1,6 @@
 -- delimiter: |; ignoreOrder: true; ignoreExcessRows: true; trimValues:true
  cardinality | bigint | map(k,v) | scalar | true | Returns the cardinality (the number of key-value pairs) of the map |
  map | map(K,V) | array(K), array(V) | scalar | true | Constructs a map from the given key/value arrays |
- map_agg | map(K,V) | K, V | aggregate | true | Aggregates all the rows (key/value pairs) into a single map |
+ map_agg | map(k,v) | k, v | aggregate | true | Aggregates all the rows (key/value pairs) into a single map |
  map_keys | array(k) | map(k,v) | scalar | true | Returns the keys of the given map(K,V) as an array |
  map_values | array(v) | map(k,v) | scalar | true | Returns the values of the given map(K,V) as an array |
diff --git a/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result b/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result
index 7366c5b42ec0..f635924115bf 100644
--- a/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result
+++ b/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result
@@ -10,7 +10,7 @@
  approx_percentile | double | double, double | aggregate | true | |
  approx_set | hyperloglog | bigint | aggregate | true | |
  approx_set | hyperloglog | double | aggregate | true | |
- arbitrary | T | T | aggregate | true | Return an arbitrary non-null input value |
+ arbitrary | t | t | aggregate | true | Return an arbitrary non-null input value |
  asin | double | double | scalar | true | Arc sine |
  atan | double | double | scalar | true | Arc tangent |
  atan2 | double | double, double | scalar | true | Arc tangent of given fraction |
@@ -27,7 +27,7 @@
  cos | double | double | scalar | true | Cosine |
  cosh | double | double | scalar | true | Hyperbolic cosine |
  count | bigint | | aggregate | true | |
- count | bigint | T | aggregate | true | Counts the non-null values |
+ count | bigint | t | aggregate | true | Counts the non-null values |
  count_if | bigint | boolean | aggregate | true | |
  covar_pop | double | double, double | aggregate | true | |
  covar_samp | double | double, double | aggregate | true | |
@@ -46,10 +46,10 @@
  ln | double | double | scalar | true | Natural logarithm |
  log10 | double | double | scalar | true | Logarithm to base 10 |
  log2 | double | double | scalar | true | Logarithm to base 2 |
- max | E | E | aggregate | true | Returns the maximum value of the argument |
- max_by | V | V, K | aggregate | true | Returns the value of the first argument, associated with the maximum value of the second argument |
- min | E | E | aggregate | true | Returns the minimum value of the argument |
- min_by | V | V, K | aggregate | true | Returns the value of the first argument, associated with the minimum value of the second argument |
+ max | t | t | aggregate | true | Returns the maximum value of the argument |
+ max_by | v | v, k | aggregate | true | Returns the value of the first argument, associated with the maximum value of the second argument |
+ min | t | t | aggregate | true | Returns the minimum value of the argument |
+ min_by | v | v, k | aggregate | true | Returns the value of the first argument, associated with the minimum value of the second argument |
  mod | bigint | bigint, bigint | scalar | true | Remainder of given quotient |
  mod | double | double, double | scalar | true | Remainder of given quotient |
  nan | double | | scalar | true | Constant representing not-a-number |