From ba00941f345e80e778df7e3809ce6b91255b181c Mon Sep 17 00:00:00 2001 From: Aleksander Kramarz Date: Wed, 24 Jan 2024 19:03:34 +0100 Subject: [PATCH 1/3] Add test case reproducing the issue --- ...MultiArgumentLongTimestampAggregation.java | 41 ++++++++++++++++ .../aggregation/TestAccumulatorCompiler.java | 49 ++++++++++++------- 2 files changed, 72 insertions(+), 18 deletions(-) create mode 100644 core/trino-main/src/test/java/io/trino/operator/aggregation/MultiArgumentLongTimestampAggregation.java diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/MultiArgumentLongTimestampAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/MultiArgumentLongTimestampAggregation.java new file mode 100644 index 000000000000..426ae4a6e340 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/MultiArgumentLongTimestampAggregation.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.LongTimestamp; + +import static io.trino.spi.type.BigintType.BIGINT; + +public final class MultiArgumentLongTimestampAggregation +{ + private MultiArgumentLongTimestampAggregation() {} + + public static void input(LongTimestampAggregationState state, + LongTimestamp arg1, LongTimestamp arg2, LongTimestamp arg3, LongTimestamp arg4, + LongTimestamp arg5, LongTimestamp arg6, LongTimestamp arg7, LongTimestamp arg8) + { + state.setValue(state.getValue() + 1); + } + + public static void combine(LongTimestampAggregationState stateA, LongTimestampAggregationState stateB) + { + stateA.setValue(stateA.getValue() + stateB.getValue()); + } + + public static void output(LongTimestampAggregationState state, BlockBuilder blockBuilder) + { + BIGINT.writeLong(blockBuilder, state.getValue()); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index 0d2561864be8..4891a775e0a0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -14,6 +14,7 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import io.airlift.bytecode.DynamicClassLoader; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -25,12 +26,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.LongArrayBlockBuilder; -import io.trino.spi.function.AccumulatorState; -import io.trino.spi.function.AccumulatorStateFactory; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.function.AggregationImplementation; -import io.trino.spi.function.BoundSignature; -import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.*; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.RealType; import io.trino.spi.type.TimestampType; @@ -39,10 +35,13 @@ import java.lang.invoke.MethodHandle; import java.lang.reflect.Constructor; +import java.util.Arrays; +import java.util.Collections; import java.util.Optional; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory; +import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; @@ -56,15 +55,22 @@ public class TestAccumulatorCompiler @Test public void testAccumulatorCompilerForTypeSpecificObjectParameter() { - testAccumulatorCompilerForTypeSpecificObjectParameter(true); - testAccumulatorCompilerForTypeSpecificObjectParameter(false); + testAccumulatorCompilerForTypeSpecificObjectParameter(LongTimestampAggregation.class, true); + testAccumulatorCompilerForTypeSpecificObjectParameter(LongTimestampAggregation.class, false); } - private void testAccumulatorCompilerForTypeSpecificObjectParameter(boolean specializedLoops) + @Test + public void testAccumulatorCompilerForTypeSpecificObjectParameterMultipleInputArgs() + { + testAccumulatorCompilerForTypeSpecificObjectParameter(MultiArgumentLongTimestampAggregation.class, true); + testAccumulatorCompilerForTypeSpecificObjectParameter(MultiArgumentLongTimestampAggregation.class, false); + } + + private void testAccumulatorCompilerForTypeSpecificObjectParameter(Class aggregation, boolean specializedLoops) { TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); - assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class, specializedLoops); + assertGenerateAccumulator(aggregation, LongTimestampAggregationState.class, specializedLoops); } @Test @@ -102,12 +108,19 @@ private static void assertGenerateAccumulator(Cl AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(stateInterface); AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateInterface); + Class[] inputArgTypes = Arrays.stream(aggregation.getMethods()) + .filter(m -> m.getName().equals("input")).findFirst().get() + .getParameterTypes(); + int inputArgCount = inputArgTypes.length - 1; + BoundSignature signature = new BoundSignature( builtinFunctionName("longTimestampAggregation"), RealType.REAL, - ImmutableList.of(TIMESTAMP_PICOS)); - MethodHandle inputFunction = methodHandle(aggregation, "input", stateInterface, LongTimestamp.class); - inputFunction = normalizeInputMethod(inputFunction, signature, STATE, INPUT_CHANNEL); + Collections.nCopies(inputArgCount, TIMESTAMP_PICOS)); + MethodHandle inputFunction = methodHandle(aggregation, "input", inputArgTypes); + inputFunction = normalizeInputMethod( + inputFunction, signature, + Lists.asList(STATE, Collections.nCopies(inputArgCount, INPUT_CHANNEL).toArray(AggregationParameterKind[]::new))); MethodHandle combineFunction = methodHandle(aggregation, "combine", stateInterface, stateInterface); MethodHandle outputFunction = methodHandle(aggregation, "output", stateInterface, BlockBuilder.class); AggregationImplementation implementation = AggregationImplementation.builder() @@ -116,7 +129,7 @@ private static void assertGenerateAccumulator(Cl .outputFunction(outputFunction) .accumulatorStateDescriptor(stateInterface, stateSerializer, stateFactory) .build(); - FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false)); + FunctionNullability functionNullability = new FunctionNullability(false, Collections.nCopies(inputArgCount, false)); // test if we can compile aggregation AccumulatorFactory accumulatorFactory = generateAccumulatorFactory(signature, implementation, functionNullability, specializedLoops); @@ -137,17 +150,17 @@ private static void assertGenerateAccumulator(Cl windowAccumulator.evaluateFinal(new LongArrayBlockBuilder(null, 1)); TestingAggregationFunction aggregationFunction = new TestingAggregationFunction( - ImmutableList.of(TIMESTAMP_PICOS), + Collections.nCopies(inputArgCount, TIMESTAMP_PICOS), ImmutableList.of(BIGINT), BIGINT, accumulatorFactory); - assertThat(AggregationTestUtils.aggregation(aggregationFunction, createPage(1234))).isEqualTo(1234L); + assertThat(AggregationTestUtils.aggregation(aggregationFunction, createPage(1234, inputArgCount))).isEqualTo(1234L); } - private static Page createPage(int count) + private static Page createPage(int count, int repeat) { Block timestampSequenceBlock = createTimestampSequenceBlock(count); - return new Page(timestampSequenceBlock.getPositionCount(), timestampSequenceBlock); + return new Page(timestampSequenceBlock.getPositionCount(), Collections.nCopies(repeat, timestampSequenceBlock).toArray(Block[]::new)); } private static Block createTimestampSequenceBlock(int count) From ac42eb888731c2691379f299bce75a0f54a7638a Mon Sep 17 00:00:00 2001 From: Aleksander Kramarz Date: Thu, 18 Jan 2024 16:30:03 +0100 Subject: [PATCH 2/3] Allow for specialized loops only up to 6 parameters --- .../io/trino/operator/aggregation/AccumulatorCompiler.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 1a5e6a312bbf..bc6ca160d359 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -85,6 +85,8 @@ public final class AccumulatorCompiler { + private static final int MAX_ARGS_FOR_SPECIALIZED_LOOP = 6; + private AccumulatorCompiler() {} public static AccumulatorFactory generateAccumulatorFactory( @@ -536,11 +538,12 @@ private static BytecodeBlock generateInputForLoop( CallSiteBinder callSiteBinder, boolean grouped) { - if (specializedLoops) { + int parameterCount = parameterVariables.size(); + if (specializedLoops && parameterCount <= MAX_ARGS_FOR_SPECIALIZED_LOOP) { BytecodeBlock newBlock = new BytecodeBlock(); Variable thisVariable = scope.getThis(); - MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterVariables.size(), grouped); + MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterCount, grouped); ImmutableList.Builder parameters = ImmutableList.builder(); parameters.add(mask); From 538ddd00dee72b1169e215d97020069a7094babf Mon Sep 17 00:00:00 2001 From: Aleksander Kramarz Date: Thu, 25 Jan 2024 10:15:48 +0100 Subject: [PATCH 3/3] Fix style --- .../operator/aggregation/TestAccumulatorCompiler.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index 4891a775e0a0..43807c040c55 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -26,7 +26,12 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.LongArrayBlockBuilder; -import io.trino.spi.function.*; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.RealType; import io.trino.spi.type.TimestampType;