Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@

public final class AccumulatorCompiler
{
private static final int MAX_ARGS_FOR_SPECIALIZED_LOOP = 6;

private AccumulatorCompiler() {}

public static AccumulatorFactory generateAccumulatorFactory(
Expand Down Expand Up @@ -536,11 +538,12 @@ private static BytecodeBlock generateInputForLoop(
CallSiteBinder callSiteBinder,
boolean grouped)
{
if (specializedLoops) {
int parameterCount = parameterVariables.size();
if (specializedLoops && parameterCount <= MAX_ARGS_FOR_SPECIALIZED_LOOP) {
BytecodeBlock newBlock = new BytecodeBlock();
Variable thisVariable = scope.getThis();

MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterVariables.size(), grouped);
MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterCount, grouped);

ImmutableList.Builder<BytecodeExpression> parameters = ImmutableList.builder();
parameters.add(mask);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation;

import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.LongTimestamp;

import static io.trino.spi.type.BigintType.BIGINT;

public final class MultiArgumentLongTimestampAggregation
{
private MultiArgumentLongTimestampAggregation() {}

public static void input(LongTimestampAggregationState state,
LongTimestamp arg1, LongTimestamp arg2, LongTimestamp arg3, LongTimestamp arg4,
LongTimestamp arg5, LongTimestamp arg6, LongTimestamp arg7, LongTimestamp arg8)
{
state.setValue(state.getValue() + 1);
}

public static void combine(LongTimestampAggregationState stateA, LongTimestampAggregationState stateB)
{
stateA.setValue(stateA.getValue() + stateB.getValue());
}

public static void output(LongTimestampAggregationState state, BlockBuilder blockBuilder)
{
BIGINT.writeLong(blockBuilder, state.getValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
Expand All @@ -39,10 +40,13 @@

import java.lang.invoke.MethodHandle;
import java.lang.reflect.Constructor;
import java.util.Arrays;
import java.util.Collections;
import java.util.Optional;

import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod;
Expand All @@ -56,15 +60,22 @@ public class TestAccumulatorCompiler
@Test
public void testAccumulatorCompilerForTypeSpecificObjectParameter()
{
testAccumulatorCompilerForTypeSpecificObjectParameter(true);
testAccumulatorCompilerForTypeSpecificObjectParameter(false);
testAccumulatorCompilerForTypeSpecificObjectParameter(LongTimestampAggregation.class, true);
testAccumulatorCompilerForTypeSpecificObjectParameter(LongTimestampAggregation.class, false);
}

private void testAccumulatorCompilerForTypeSpecificObjectParameter(boolean specializedLoops)
@Test
public void testAccumulatorCompilerForTypeSpecificObjectParameterMultipleInputArgs()
{
testAccumulatorCompilerForTypeSpecificObjectParameter(MultiArgumentLongTimestampAggregation.class, true);
testAccumulatorCompilerForTypeSpecificObjectParameter(MultiArgumentLongTimestampAggregation.class, false);
}

private <A> void testAccumulatorCompilerForTypeSpecificObjectParameter(Class<A> aggregation, boolean specializedLoops)
{
TimestampType parameterType = TimestampType.TIMESTAMP_NANOS;
assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class);
assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class, specializedLoops);
assertGenerateAccumulator(aggregation, LongTimestampAggregationState.class, specializedLoops);
}

@Test
Expand Down Expand Up @@ -102,12 +113,19 @@ private static <S extends AccumulatorState, A> void assertGenerateAccumulator(Cl
AccumulatorStateSerializer<S> stateSerializer = StateCompiler.generateStateSerializer(stateInterface);
AccumulatorStateFactory<S> stateFactory = StateCompiler.generateStateFactory(stateInterface);

Class<?>[] inputArgTypes = Arrays.stream(aggregation.getMethods())
.filter(m -> m.getName().equals("input")).findFirst().get()
.getParameterTypes();
int inputArgCount = inputArgTypes.length - 1;

BoundSignature signature = new BoundSignature(
builtinFunctionName("longTimestampAggregation"),
RealType.REAL,
ImmutableList.of(TIMESTAMP_PICOS));
MethodHandle inputFunction = methodHandle(aggregation, "input", stateInterface, LongTimestamp.class);
inputFunction = normalizeInputMethod(inputFunction, signature, STATE, INPUT_CHANNEL);
Collections.nCopies(inputArgCount, TIMESTAMP_PICOS));
MethodHandle inputFunction = methodHandle(aggregation, "input", inputArgTypes);
inputFunction = normalizeInputMethod(
inputFunction, signature,
Lists.asList(STATE, Collections.nCopies(inputArgCount, INPUT_CHANNEL).toArray(AggregationParameterKind[]::new)));
MethodHandle combineFunction = methodHandle(aggregation, "combine", stateInterface, stateInterface);
MethodHandle outputFunction = methodHandle(aggregation, "output", stateInterface, BlockBuilder.class);
AggregationImplementation implementation = AggregationImplementation.builder()
Expand All @@ -116,7 +134,7 @@ private static <S extends AccumulatorState, A> void assertGenerateAccumulator(Cl
.outputFunction(outputFunction)
.accumulatorStateDescriptor(stateInterface, stateSerializer, stateFactory)
.build();
FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false));
FunctionNullability functionNullability = new FunctionNullability(false, Collections.nCopies(inputArgCount, false));

// test if we can compile aggregation
AccumulatorFactory accumulatorFactory = generateAccumulatorFactory(signature, implementation, functionNullability, specializedLoops);
Expand All @@ -137,17 +155,17 @@ private static <S extends AccumulatorState, A> void assertGenerateAccumulator(Cl
windowAccumulator.evaluateFinal(new LongArrayBlockBuilder(null, 1));

TestingAggregationFunction aggregationFunction = new TestingAggregationFunction(
ImmutableList.of(TIMESTAMP_PICOS),
Collections.nCopies(inputArgCount, TIMESTAMP_PICOS),
ImmutableList.of(BIGINT),
BIGINT,
accumulatorFactory);
assertThat(AggregationTestUtils.aggregation(aggregationFunction, createPage(1234))).isEqualTo(1234L);
assertThat(AggregationTestUtils.aggregation(aggregationFunction, createPage(1234, inputArgCount))).isEqualTo(1234L);
}

private static Page createPage(int count)
private static Page createPage(int count, int repeat)
{
Block timestampSequenceBlock = createTimestampSequenceBlock(count);
return new Page(timestampSequenceBlock.getPositionCount(), timestampSequenceBlock);
return new Page(timestampSequenceBlock.getPositionCount(), Collections.nCopies(repeat, timestampSequenceBlock).toArray(Block[]::new));
}

private static Block createTimestampSequenceBlock(int count)
Expand Down