Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;

import java.util.Optional;

public interface Accumulator
{
long getEstimatedSize();

Accumulator copy();

void addInput(Page arguments, Optional<Block> mask);
void addInput(Page arguments, AggregationMask mask);

void addIntermediate(Block block);

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ public interface AccumulatorFactory
GroupedAccumulator createGroupedAccumulator(List<Supplier<Object>> lambdaProviders);

GroupedAccumulator createGroupedIntermediateAccumulator(List<Supplier<Object>> lambdaProviders);

AggregationMaskBuilder createAggregationMaskBuilder();
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand All @@ -34,7 +32,6 @@
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
import static java.lang.invoke.MethodHandles.collectArguments;
import static java.lang.invoke.MethodHandles.lookup;
import static java.lang.invoke.MethodHandles.permuteArguments;
import static java.util.Objects.requireNonNull;

public final class AggregationFunctionAdapter
Expand Down Expand Up @@ -103,7 +100,6 @@ public static MethodHandle normalizeInputMethod(
List<AggregationParameterKind> inputArgumentKinds = parameterKinds.stream()
.filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL)
.collect(toImmutableList());
boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL);

checkArgument(
boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(),
Expand All @@ -113,19 +109,21 @@ public static MethodHandle normalizeInputMethod(

List<AggregationParameterKind> expectedInputArgumentKinds = new ArrayList<>();
expectedInputArgumentKinds.addAll(stateArgumentKinds);
expectedInputArgumentKinds.addAll(inputArgumentKinds);
if (hasInputChannel) {
expectedInputArgumentKinds.add(BLOCK_INDEX);
for (AggregationParameterKind kind : inputArgumentKinds) {
expectedInputArgumentKinds.add(kind);
if (kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL) {
expectedInputArgumentKinds.add(BLOCK_INDEX);
}
}

checkArgument(
expectedInputArgumentKinds.equals(parameterKinds),
"Expected input parameter kinds %s, but got %s",
expectedInputArgumentKinds,
parameterKinds);

MethodType inputMethodType = inputMethod.type();
for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) {
int parameterIndex = stateArgumentKinds.size() + argumentIndex;
int parameterIndex = stateArgumentKinds.size() + (argumentIndex * 2);
AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex);
if (inputArgument != INPUT_CHANNEL) {
continue;
Expand All @@ -145,27 +143,9 @@ else if (argumentType.getJavaType().equals(double.class)) {
}
else {
valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType);
valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex)));
valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethod.type().parameterType(parameterIndex)));
}
inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter);

// move the position argument to the end (and combine with other existing position argument)
inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class);

ArrayList<Integer> reorder;
if (hasInputChannel) {
reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount);
}
else {
inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class);
reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount;
reorder.remove(positionParameterIndex);
reorder.add(parameterIndex + 1, positionParameterIndex);
hasInputChannel = true;
}
inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray());
}
return inputMethod;
}
Expand Down
Loading