Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ public static Constructor<? extends WindowAccumulator> generateWindowAccumulator
AggregationImplementation implementation,
FunctionNullability functionNullability)
{
// change types used in Aggregation methods to types used in the core Trino engine to simplify code generation
implementation = normalizeAggregationMethods(implementation);

DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader());

List<Boolean> argumentNullable = functionNullability.getArgumentNullable()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.operator.window.InternalWindowIndex;
import io.trino.server.PluginManager;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.Int96ArrayBlock;
import io.trino.spi.block.LongArrayBlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
Expand All @@ -33,6 +38,8 @@
import org.testng.annotations.Test;

import java.lang.invoke.MethodHandle;
import java.lang.reflect.Constructor;
import java.util.Optional;

import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
Expand Down Expand Up @@ -96,7 +103,20 @@ private static <S extends AccumulatorState, A> void assertGenerateAccumulator(Cl
// test if we can compile aggregation
AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, implementation, functionNullability);
assertThat(accumulatorFactory).isNotNull();
assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, implementation, functionNullability)).isNotNull();

// compile window aggregation
Constructor<? extends WindowAccumulator> actual = AccumulatorCompiler.generateWindowAccumulatorClass(signature, implementation, functionNullability);
assertThat(actual).isNotNull();
WindowAccumulator windowAccumulator;
try {
windowAccumulator = actual.newInstance(ImmutableList.of());
}
catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
// call the functions to ensure that the code does not reference the wrong state
windowAccumulator.addInput(new TestWindowIndex(), 0, 5);
windowAccumulator.evaluateFinal(new LongArrayBlockBuilder(null, 1));

TestingAggregationFunction aggregationFunction = new TestingAggregationFunction(
ImmutableList.of(TIMESTAMP_PICOS),
Expand All @@ -120,4 +140,74 @@ private static Block createTimestampSequenceBlock(int count)
}
return builder.build();
}

private static class TestWindowIndex
implements InternalWindowIndex
{
@Override
public int size()
{
return 10;
}

@Override
public boolean isNull(int channel, int position)
{
return false;
}

@Override
public boolean getBoolean(int channel, int position)
{
return false;
}

@Override
public long getLong(int channel, int position)
{
return 0;
}

@Override
public double getDouble(int channel, int position)
{
return 0;
}

@Override
public Slice getSlice(int channel, int position)
{
return Slices.EMPTY_SLICE;
}

@Override
public Block getSingleValueBlock(int channel, int position)
{
return null;
}

@Override
public Object getObject(int channel, int position)
{
return null;
}

@Override
public void appendTo(int channel, int position, BlockBuilder output)
{
output.appendNull();
}

@Override
public Block getRawBlock(int channel, int position)
{
return new Int96ArrayBlock(1, Optional.empty(), new long[] {0}, new int[] {0});
}

@Override
public int getRawBlockPosition(int position)
{
return 0;
}
}
}