diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index dd1f56c53a3c..24edcb571d79 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -219,6 +219,9 @@ public static Constructor generateWindowAccumulator AggregationImplementation implementation, FunctionNullability functionNullability) { + // change types used in Aggregation methods to types used in the core Trino engine to simplify code generation + implementation = normalizeAggregationMethods(implementation); + DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader()); List argumentNullable = functionNullability.getArgumentNullable() diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index a84db0a52e2a..0722e170e7c5 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -15,11 +15,16 @@ import com.google.common.collect.ImmutableList; import io.airlift.bytecode.DynamicClassLoader; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.operator.aggregation.state.StateCompiler; +import io.trino.operator.window.InternalWindowIndex; import io.trino.server.PluginManager; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int96ArrayBlock; +import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateSerializer; @@ -33,6 +38,8 @@ import org.testng.annotations.Test; import java.lang.invoke.MethodHandle; +import java.lang.reflect.Constructor; +import java.util.Optional; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; @@ -96,7 +103,20 @@ private static void assertGenerateAccumulator(Cl // test if we can compile aggregation AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, implementation, functionNullability); assertThat(accumulatorFactory).isNotNull(); - assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, implementation, functionNullability)).isNotNull(); + + // compile window aggregation + Constructor actual = AccumulatorCompiler.generateWindowAccumulatorClass(signature, implementation, functionNullability); + assertThat(actual).isNotNull(); + WindowAccumulator windowAccumulator; + try { + windowAccumulator = actual.newInstance(ImmutableList.of()); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + // call the functions to ensure that the code does not reference the wrong state + windowAccumulator.addInput(new TestWindowIndex(), 0, 5); + windowAccumulator.evaluateFinal(new LongArrayBlockBuilder(null, 1)); TestingAggregationFunction aggregationFunction = new TestingAggregationFunction( ImmutableList.of(TIMESTAMP_PICOS), @@ -120,4 +140,74 @@ private static Block createTimestampSequenceBlock(int count) } return builder.build(); } + + private static class TestWindowIndex + implements InternalWindowIndex + { + @Override + public int size() + { + return 10; + } + + @Override + public boolean isNull(int channel, int position) + { + return false; + } + + @Override + public boolean getBoolean(int channel, int position) + { + return false; + } + + @Override + public long getLong(int channel, int position) + { + return 0; + } + + @Override + public double getDouble(int channel, int position) + { + return 0; + } + + @Override + public Slice getSlice(int channel, int position) + { + return Slices.EMPTY_SLICE; + } + + @Override + public Block getSingleValueBlock(int channel, int position) + { + return null; + } + + @Override + public Object getObject(int channel, int position) + { + return null; + } + + @Override + public void appendTo(int channel, int position, BlockBuilder output) + { + output.appendNull(); + } + + @Override + public Block getRawBlock(int channel, int position) + { + return new Int96ArrayBlock(1, Optional.empty(), new long[] {0}, new int[] {0}); + } + + @Override + public int getRawBlockPosition(int position) + { + return 0; + } + } }