diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/Accumulator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/Accumulator.java index 34d2f9338602..6edae8fd67a5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/Accumulator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/Accumulator.java @@ -17,15 +17,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import java.util.Optional; - public interface Accumulator { long getEstimatedSize(); Accumulator copy(); - void addInput(Page arguments, Optional mask); + void addInput(Page arguments, AggregationMask mask); void addIntermediate(Block block); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index dd1f56c53a3c..c729239ebf60 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableList; import io.airlift.bytecode.BytecodeBlock; -import io.airlift.bytecode.BytecodeNode; import io.airlift.bytecode.ClassDefinition; import io.airlift.bytecode.DynamicClassLoader; import io.airlift.bytecode.FieldDefinition; @@ -67,12 +66,14 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; -import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull; import static io.airlift.bytecode.expression.BytecodeExpressions.constantString; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; import static io.airlift.bytecode.expression.BytecodeExpressions.not; +import static io.trino.operator.aggregation.AggregationLoopBuilder.buildLoop; +import static io.trino.operator.aggregation.AggregationLoopBuilder.toGroupedInputFunction; +import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder; import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD; import static io.trino.sql.gen.BytecodeUtils.invoke; import static io.trino.sql.gen.BytecodeUtils.loadConstant; @@ -88,7 +89,8 @@ private AccumulatorCompiler() {} public static AccumulatorFactory generateAccumulatorFactory( BoundSignature boundSignature, AggregationImplementation implementation, - FunctionNullability functionNullability) + FunctionNullability functionNullability, + boolean specializedLoops) { // change types used in Aggregation methods to types used in the core Trino engine to simplify code generation implementation = normalizeAggregationMethods(implementation); @@ -103,19 +105,30 @@ public static AccumulatorFactory generateAccumulatorFactory( Accumulator.class, implementation, argumentNullable, - classLoader); + classLoader, + specializedLoops); Constructor groupedAccumulatorConstructor = generateAccumulatorClass( boundSignature, GroupedAccumulator.class, implementation, argumentNullable, - classLoader); + classLoader, + specializedLoops); + + List nonNullArguments = new ArrayList<>(); + for (int argumentIndex = 0; argumentIndex < argumentNullable.size(); argumentIndex++) { + if (!argumentNullable.get(argumentIndex)) { + nonNullArguments.add(argumentIndex); + } + } + Constructor maskBuilderConstructor = generateAggregationMaskBuilder(nonNullArguments.stream().mapToInt(Integer::intValue).toArray()); return new CompiledAccumulatorFactory( accumulatorConstructor, groupedAccumulatorConstructor, - implementation.getLambdaInterfaces()); + implementation.getLambdaInterfaces(), + maskBuilderConstructor); } private static Constructor generateAccumulatorClass( @@ -123,7 +136,8 @@ private static Constructor generateAccumulatorClass( Class accumulatorInterface, AggregationImplementation implementation, List argumentNullable, - DynamicClassLoader classLoader) + DynamicClassLoader classLoader, + boolean specializedLoops) { boolean grouped = accumulatorInterface == GroupedAccumulator.class; @@ -171,6 +185,7 @@ private static Constructor generateAccumulatorClass( generateAddInput( definition, + specializedLoops, stateFields, argumentNullable, lambdaProviderFields, @@ -179,6 +194,10 @@ private static Constructor generateAccumulatorClass( grouped); generateGetEstimatedSize(definition, stateFields); + if (grouped) { + generateSetGroupCount(definition, stateFields); + } + generateAddIntermediateAsCombine( definition, stateFieldAndDescriptors, @@ -332,8 +351,22 @@ private static void generateGetEstimatedSize(ClassDefinition definition, List stateFields) + { + Parameter groupCount = arg("groupCount", long.class); + + MethodDefinition method = definition.declareMethod(a(PUBLIC), "setGroupCount", type(void.class), groupCount); + BytecodeBlock body = method.getBody(); + for (FieldDefinition stateField : stateFields) { + BytecodeExpression state = method.getScope().getThis().getField(stateField); + body.append(state.invoke("ensureCapacity", void.class, groupCount)); + } + body.ret(); + } + private static void generateAddInput( ClassDefinition definition, + boolean specializedLoops, List stateField, List argumentNullable, List lambdaProviderFields, @@ -347,22 +380,17 @@ private static void generateAddInput( } Parameter arguments = arg("arguments", Page.class); parameters.add(arguments); - Parameter mask = arg("mask", Optional.class); + Parameter mask = arg("mask", AggregationMask.class); parameters.add(mask); MethodDefinition method = definition.declareMethod(a(PUBLIC), "addInput", type(void.class), parameters.build()); Scope scope = method.getScope(); BytecodeBlock body = method.getBody(); - if (grouped) { - generateEnsureCapacity(scope, stateField, body); - } - List parameterVariables = new ArrayList<>(); for (int i = 0; i < argumentNullable.size(); i++) { parameterVariables.add(scope.declareVariable(Block.class, "block" + i)); } - Variable masksBlock = scope.declareVariable("masksBlock", body, mask.invoke("orElse", Object.class, constantNull(Object.class)).cast(Block.class)); // Get all parameter blocks for (int i = 0; i < parameterVariables.size(); i++) { @@ -371,14 +399,13 @@ private static void generateAddInput( } BytecodeBlock block = generateInputForLoop( - arguments, + specializedLoops, stateField, - argumentNullable, inputFunction, scope, parameterVariables, lambdaProviderFields, - masksBlock, + mask, callSiteBinder, grouped); @@ -468,10 +495,6 @@ private static List getInvokeFunctionOnWindowIndexParameters // input parameters for (int i = 0; i < inputParameterCount; i++) { expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position)); - } - - // position parameter - if (inputParameterCount > 0) { expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position)); } @@ -485,72 +508,93 @@ private static List getInvokeFunctionOnWindowIndexParameters } private static BytecodeBlock generateInputForLoop( - Variable arguments, + boolean specializedLoops, List stateField, - List argumentNullable, MethodHandle inputFunction, Scope scope, List parameterVariables, List lambdaProviderFields, - Variable masksBlock, + Variable mask, CallSiteBinder callSiteBinder, boolean grouped) { + if (specializedLoops) { + if (grouped) { + inputFunction = toGroupedInputFunction(inputFunction, stateField.size()); + } + BytecodeBlock newBlock = new BytecodeBlock(); + Variable thisVariable = scope.getThis(); + + MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterVariables.size() + (grouped ? 1 : 0), lambdaProviderFields.size()); + + ImmutableList.Builder parameters = ImmutableList.builder(); + parameters.add(mask); + for (FieldDefinition fieldDefinition : stateField) { + parameters.add(thisVariable.getField(fieldDefinition)); + } + if (grouped) { + parameters.add(scope.getVariable("groupIdsBlock")); + } + parameters.addAll(parameterVariables); + for (FieldDefinition lambdaProviderField : lambdaProviderFields) { + parameters.add(scope.getThis().getField(lambdaProviderField) + .invoke("get", Object.class)); + } + + newBlock.append(invoke(callSiteBinder.bind(mainLoop), "mainLoop", parameters.build())); + return newBlock; + } + // For-loop over rows Variable positionVariable = scope.declareVariable(int.class, "position"); Variable rowsVariable = scope.declareVariable(int.class, "rows"); + Variable selectedPositionsArrayVariable = scope.declareVariable(int[].class, "selectedPositionsArray"); + Variable selectedPositionVariable = scope.declareVariable(int.class, "selectedPosition"); BytecodeBlock block = new BytecodeBlock() - .append(arguments) - .invokeVirtual(Page.class, "getPositionCount", int.class) - .putVariable(rowsVariable) + .initializeVariable(rowsVariable) + .initializeVariable(selectedPositionVariable) .initializeVariable(positionVariable); - BytecodeNode loopBody = generateInvokeInputFunction( - scope, - stateField, - positionVariable, - parameterVariables, - lambdaProviderFields, - inputFunction, - callSiteBinder, - grouped); - - // Wrap with null checks - for (int parameterIndex = 0; parameterIndex < parameterVariables.size(); parameterIndex++) { - if (!argumentNullable.get(parameterIndex)) { - Variable variableDefinition = parameterVariables.get(parameterIndex); - loopBody = new IfStatement("if(!%s.isNull(position))", variableDefinition.getName()) - .condition(new BytecodeBlock() - .getVariable(variableDefinition) - .getVariable(positionVariable) - .invokeInterface(Block.class, "isNull", boolean.class, int.class)) - .ifFalse(loopBody); - } - } - - loopBody = new IfStatement("if(testMask(%s, position))", masksBlock.getName()) - .condition(new BytecodeBlock() - .getVariable(masksBlock) - .getVariable(positionVariable) - .invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class)) - .ifTrue(loopBody); - - ForLoop forLoop = new ForLoop() - .initialize(new BytecodeBlock().putVariable(positionVariable, 0)) - .condition(new BytecodeBlock() - .getVariable(positionVariable) - .getVariable(rowsVariable) - .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)) + ForLoop selectAllLoop = new ForLoop() + .initialize(new BytecodeBlock() + .append(rowsVariable.set(mask.invoke("getPositionCount", int.class))) + .append(positionVariable.set(constantInt(0)))) + .condition(BytecodeExpressions.lessThan(positionVariable, rowsVariable)) .update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1)) - .body(loopBody); - - block.append(new IfStatement("if(!maskGuaranteedToFilterAllRows(%s, %s))", rowsVariable.getName(), masksBlock.getName()) - .condition(new BytecodeBlock() - .getVariable(rowsVariable) - .getVariable(masksBlock) - .invokeStatic(AggregationUtils.class, "maskGuaranteedToFilterAllRows", boolean.class, int.class, Block.class)) - .ifFalse(forLoop)); + .body(generateInvokeInputFunction( + scope, + stateField, + positionVariable, + parameterVariables, + lambdaProviderFields, + inputFunction, + callSiteBinder, + grouped)); + + ForLoop selectedPositionsLoop = new ForLoop() + .initialize(new BytecodeBlock() + .append(rowsVariable.set(mask.invoke("getSelectedPositionCount", int.class))) + .append(selectedPositionsArrayVariable.set(mask.invoke("getSelectedPositions", int[].class))) + .append(positionVariable.set(constantInt(0)))) + .condition(BytecodeExpressions.lessThan(positionVariable, rowsVariable)) + .update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1)) + .body(new BytecodeBlock() + .append(selectedPositionVariable.set(selectedPositionsArrayVariable.getElement(positionVariable))) + .append(generateInvokeInputFunction( + scope, + stateField, + selectedPositionVariable, + parameterVariables, + lambdaProviderFields, + inputFunction, + callSiteBinder, + grouped))); + + block.append(new IfStatement() + .condition(mask.invoke("isSelectAll", boolean.class)) + .ifTrue(selectAllLoop) + .ifFalse(selectedPositionsLoop)); return block; } @@ -568,7 +612,7 @@ private static BytecodeBlock generateInvokeInputFunction( BytecodeBlock block = new BytecodeBlock(); if (grouped) { - generateSetGroupIdFromGroupIdsBlock(scope, stateField, block); + generateSetGroupIdFromGroupIdsBlock(scope, stateField, block, position); } block.comment("Call input function with unpacked Block arguments"); @@ -581,10 +625,8 @@ private static BytecodeBlock generateInvokeInputFunction( } // input parameters - parameters.addAll(parameterVariables); - - // position parameter - if (!parameterVariables.isEmpty()) { + for (Variable variable : parameterVariables) { + parameters.add(variable); parameters.add(position); } @@ -659,10 +701,6 @@ private static void generateAddIntermediateAsCombine( .map(StateFieldAndDescriptor::getStateField) .collect(toImmutableList()); - if (grouped) { - generateEnsureCapacity(scope, stateFields, body); - } - BytecodeBlock loopBody = new BytecodeBlock(); loopBody.comment("combine(state_0, state_1, ... scratchState_0, scratchState_1, ... lambda_0, lambda_1, ...)"); @@ -697,25 +735,15 @@ private static void generateAddIntermediateAsCombine( .ret(); } - private static void generateSetGroupIdFromGroupIdsBlock(Scope scope, List stateFields, BytecodeBlock block) + private static void generateSetGroupIdFromGroupIdsBlock(Scope scope, List stateFields, BytecodeBlock block, Variable position) { Variable groupIdsBlock = scope.getVariable("groupIdsBlock"); - Variable position = scope.getVariable("position"); for (FieldDefinition stateField : stateFields) { BytecodeExpression state = scope.getThis().getField(stateField); block.append(state.invoke("setGroupId", void.class, groupIdsBlock.invoke("getGroupId", long.class, position))); } } - private static void generateEnsureCapacity(Scope scope, List stateFields, BytecodeBlock block) - { - Variable groupIdsBlock = scope.getVariable("groupIdsBlock"); - for (FieldDefinition stateField : stateFields) { - BytecodeExpression state = scope.getThis().getField(stateField); - block.append(state.invoke("ensureCapacity", void.class, groupIdsBlock.invoke("getGroupCount", long.class))); - } - } - private static MethodDefinition declareAddIntermediate(ClassDefinition definition, boolean grouped) { ImmutableList.Builder parameters = ImmutableList.builder(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java index 374ce655a2c2..8aff4a287419 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java @@ -27,4 +27,6 @@ public interface AccumulatorFactory GroupedAccumulator createGroupedAccumulator(List> lambdaProviders); GroupedAccumulator createGroupedIntermediateAccumulator(List> lambdaProviders); + + AggregationMaskBuilder createAggregationMaskBuilder(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java index 84d20bfddf86..e79ae02f5df5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java @@ -22,8 +22,6 @@ import java.lang.invoke.MethodType; import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -34,7 +32,6 @@ import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static java.lang.invoke.MethodHandles.collectArguments; import static java.lang.invoke.MethodHandles.lookup; -import static java.lang.invoke.MethodHandles.permuteArguments; import static java.util.Objects.requireNonNull; public final class AggregationFunctionAdapter @@ -103,7 +100,6 @@ public static MethodHandle normalizeInputMethod( List inputArgumentKinds = parameterKinds.stream() .filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL) .collect(toImmutableList()); - boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL); checkArgument( boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(), @@ -113,19 +109,21 @@ public static MethodHandle normalizeInputMethod( List expectedInputArgumentKinds = new ArrayList<>(); expectedInputArgumentKinds.addAll(stateArgumentKinds); - expectedInputArgumentKinds.addAll(inputArgumentKinds); - if (hasInputChannel) { - expectedInputArgumentKinds.add(BLOCK_INDEX); + for (AggregationParameterKind kind : inputArgumentKinds) { + expectedInputArgumentKinds.add(kind); + if (kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL) { + expectedInputArgumentKinds.add(BLOCK_INDEX); + } } + checkArgument( expectedInputArgumentKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedInputArgumentKinds, parameterKinds); - MethodType inputMethodType = inputMethod.type(); for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) { - int parameterIndex = stateArgumentKinds.size() + argumentIndex; + int parameterIndex = stateArgumentKinds.size() + (argumentIndex * 2); AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex); if (inputArgument != INPUT_CHANNEL) { continue; @@ -145,27 +143,9 @@ else if (argumentType.getJavaType().equals(double.class)) { } else { valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType); - valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex))); + valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethod.type().parameterType(parameterIndex))); } inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter); - - // move the position argument to the end (and combine with other existing position argument) - inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class); - - ArrayList reorder; - if (hasInputChannel) { - reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new)); - reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount); - } - else { - inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class); - reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new)); - int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount; - reorder.remove(positionParameterIndex); - reorder.add(parameterIndex + 1, positionParameterIndex); - hasInputChannel = true; - } - inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray()); } return inputMethod; } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java new file mode 100644 index 000000000000..d1de1391ca40 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java @@ -0,0 +1,406 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableList; +import io.trino.operator.GroupByIdBlock; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.GroupedAccumulatorState; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.ArrayDeque; +import java.util.List; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.cycle; +import static com.google.common.collect.Iterables.limit; +import static java.lang.invoke.MethodHandles.collectArguments; +import static java.lang.invoke.MethodHandles.dropArguments; +import static java.lang.invoke.MethodHandles.explicitCastArguments; +import static java.lang.invoke.MethodHandles.guardWithTest; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodHandles.permuteArguments; +import static java.lang.invoke.MethodHandles.publicLookup; + +final class AggregationLoopBuilder +{ + private static final MethodHandle GET_GROUP_ID; + private static final MethodHandle SET_GROUP_ID; + + private static final MethodHandle MASK_IS_SELECT_ALL; + private static final MethodHandle MASK_GET_POSITION_COUNT; + private static final MethodHandle MASK_GET_SELECTED_POSITION_COUNT; + private static final MethodHandle MASK_GET_SELECTED_POSITIONS; + + private static final MethodHandle RLE_GET_VALUE; + + private static final MethodHandle DICTIONARY_GET_DICTIONARY; + private static final MethodHandle DICTIONARY_GET_RAW_IDS; + private static final MethodHandle DICTIONARY_GET_RAW_IDS_OFFSET; + + private static final MethodHandle IS_INSTANCE; + private static final MethodHandle INT_ADD_FUNCTION; + private static final MethodHandle INT_INCREMENT; + private static final MethodHandle INT_GREATER_THAN; + + static { + try { + GET_GROUP_ID = publicLookup().findVirtual(GroupByIdBlock.class, "getGroupId", MethodType.methodType(long.class, int.class)); + SET_GROUP_ID = publicLookup().findVirtual(GroupedAccumulatorState.class, "setGroupId", MethodType.methodType(void.class, long.class)); + + MASK_GET_POSITION_COUNT = publicLookup().findVirtual(AggregationMask.class, "getPositionCount", MethodType.methodType(int.class)); + MASK_IS_SELECT_ALL = publicLookup().findVirtual(AggregationMask.class, "isSelectAll", MethodType.methodType(boolean.class)); + MASK_GET_SELECTED_POSITION_COUNT = publicLookup().findVirtual(AggregationMask.class, "getSelectedPositionCount", MethodType.methodType(int.class)); + MASK_GET_SELECTED_POSITIONS = publicLookup().findVirtual(AggregationMask.class, "getSelectedPositions", MethodType.methodType(int[].class)); + + RLE_GET_VALUE = lookup().findVirtual(RunLengthEncodedBlock.class, "getValue", MethodType.methodType(Block.class)); + + DICTIONARY_GET_DICTIONARY = lookup().findVirtual(DictionaryBlock.class, "getDictionary", MethodType.methodType(Block.class)); + DICTIONARY_GET_RAW_IDS = lookup().findVirtual(DictionaryBlock.class, "getRawIds", MethodType.methodType(int[].class)); + DICTIONARY_GET_RAW_IDS_OFFSET = lookup().findVirtual(DictionaryBlock.class, "getRawIdsOffset", MethodType.methodType(int.class)); + + IS_INSTANCE = publicLookup().findVirtual(Class.class, "isInstance", MethodType.methodType(boolean.class, Object.class)); + INT_ADD_FUNCTION = lookup().findStatic(AggregationLoopBuilder.class, "add", MethodType.methodType(int.class, int.class, int.class)); + INT_INCREMENT = lookup().findStatic(AggregationLoopBuilder.class, "increment", MethodType.methodType(int.class, int.class)); + INT_GREATER_THAN = lookup().findStatic(AggregationLoopBuilder.class, "greaterThan", MethodType.methodType(boolean.class, int.class, int.class)); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + private AggregationLoopBuilder() {} + + /** + * Converts ungrouped input function into an input function that has an additional parameter for the GroupIdBlock that + * is used to set the group id on each aggregation state. + */ + public static MethodHandle toGroupedInputFunction(MethodHandle inputFunction, int stateCount) + { + // input function will be: + // state0, state1, ..., stateN, block0, position0, block1, position1, ..., blockN, positionN + + // convert unused boolean parameter into a call to set the groupId on the groupedState + // state0, groupedState0, groupId, state1, groupedState1, groupId, ..., stateN, groupedStateN, groupId, block0, position0, block1, position1, ..., blockN, positionN + for (int i = 0; i < stateCount; i++) { + inputFunction = collectArguments(inputFunction, 1 + (i * 3), SET_GROUP_ID); + } + + // cast duped state arguments + // state0, state0, groupId, state1, state1, groupId, ..., stateN, stateN, groupId, block0, position0, block1, position1, ..., blockN, positionN + MethodType type = inputFunction.type(); + for (int i = 0; i < stateCount; i++) { + type = type.changeParameterType(1 + (i * 3), AccumulatorState.class); + } + inputFunction = explicitCastArguments(inputFunction, type); + + // deduplicate groupId and state arguments + // state0, state1, ..., stateN, groupId, block0, position0, block1, position1, ..., blockN, positionN + int[] reorder = new int[inputFunction.type().parameterCount()]; + for (int stateIndex = 0; stateIndex < stateCount; stateIndex++) { + reorder[(stateIndex * 3)] = stateIndex; + reorder[(stateIndex * 3) + 1] = stateIndex; + reorder[(stateIndex * 3) + 2] = stateCount; + } + int nonStateParameters = inputFunction.type().parameterCount() - (stateCount * 3); + for (int i = 0; i < nonStateParameters; i++) { + reorder[(stateCount * 3) + i] = stateCount + 1 + i; + } + MethodType newType = inputFunction.type(); + for (int i = 0; i < stateCount; i++) { + newType = newType.dropParameterTypes(i + 1, i + 3); + } + newType = newType.insertParameterTypes(stateCount, long.class); + inputFunction = permuteArguments(inputFunction, newType, reorder); + + // get groupId from GroupIdBlock + // state0, state1, ..., stateN, groupIdBlock, groupIdPosition, block0, position0, block1, position1, ..., blockN, positionN + inputFunction = collectArguments(inputFunction, stateCount, GET_GROUP_ID); + + // cast groupId block + // state0, state1, ..., stateN, groupIdBlock, groupIdPosition, block0, position0, block1, position1, ..., blockN, positionN + inputFunction = explicitCastArguments(inputFunction, inputFunction.type().changeParameterType(stateCount, Block.class)); + + return inputFunction; + } + + /** + * Build a loop over the aggregation function. Internally, there are multiple loops generated that are specialized for + * RLE, Dictionary, and basic blocks, and for masked or unmasked input. The method handle is expected to have a {@link Block} and int + * position argument for each parameter. The returned method handle signature, will start with as @link {@link AggregationMask} + * and then a single {@link Block} for each parameter. + */ + public static MethodHandle buildLoop(MethodHandle function, int stateCount, int parameterCount, int lambdaParameterCount) + { + // verify signature + List> expectedParameterTypes = ImmutableList.>builder() + .addAll(function.type().parameterList().subList(0, stateCount)) + .addAll(limit(cycle(Block.class, int.class), parameterCount * 2)) + .addAll(function.type().parameterList().subList(stateCount + (parameterCount * 2), function.type().parameterCount())) + .build(); + MethodType expectedSignature = MethodType.methodType(void.class, expectedParameterTypes); + checkArgument(function.type().equals(expectedSignature), "Expected function signature to be %s, but is %s", expectedSignature, function.type()); + + // add dummy loop position to the front of the loop + function = dropArguments(function, 0, int.class); + + function = buildParameterBlockTypeSelector(function, stateCount, lambdaParameterCount, new ArrayDeque<>(parameterCount), parameterCount); + return function; + } + + /** + * Builds a method handle that switches based on the parameter block type, and then delegates to a loop for the specific block type. + * This structure is build recursively in a bottoms up style. + */ + private static MethodHandle buildParameterBlockTypeSelector(MethodHandle function, int stateCount, int lambdaParameterCount, ArrayDeque currentTypes, int remainingParameters) + { + // If there are no more parameters, build the core loop for the selected block types + if (remainingParameters == 0) { + return buildCoreLoop(function, stateCount, ImmutableList.copyOf(currentTypes)); + } + + // otherwise, recurse and build a block type specific loop for each block type + + // RLE + currentTypes.addLast(BlockType.RLE); + MethodHandle rleMethodHandle = buildParameterBlockTypeSelector(function, stateCount, lambdaParameterCount, currentTypes, remainingParameters - 1); + rleMethodHandle = preprocessRleParameter( + rleMethodHandle, + rleMethodHandle.type().parameterCount() - lambdaParameterCount - remainingParameters); + currentTypes.removeLast(); + + // Dictionary + currentTypes.addLast(BlockType.DICTIONARY); + MethodHandle dictionaryMethodHandle = buildParameterBlockTypeSelector(function, stateCount, lambdaParameterCount, currentTypes, remainingParameters - 1); + dictionaryMethodHandle = preprocessDictionaryParameter( + dictionaryMethodHandle, + dictionaryMethodHandle.type().parameterCount() - lambdaParameterCount - remainingParameters - 2); + currentTypes.removeLast(); + + // Basic + currentTypes.addLast(BlockType.BASIC); + MethodHandle basicMethodHandle = buildParameterBlockTypeSelector(function, stateCount, lambdaParameterCount, currentTypes, remainingParameters - 1); + currentTypes.removeLast(); + + // combine the block type specific loops into a single method handle that selects the correct loop based on block type. + return guardWithTest( + blockParameterIndexOf( + basicMethodHandle.type(), + basicMethodHandle.type().parameterCount() - lambdaParameterCount - remainingParameters, + RunLengthEncodedBlock.class), + rleMethodHandle, + guardWithTest( + blockParameterIndexOf( + basicMethodHandle.type(), + basicMethodHandle.type().parameterCount() - lambdaParameterCount - remainingParameters, + DictionaryBlock.class), + dictionaryMethodHandle, + basicMethodHandle)); + } + + /** + * Builds the core loop. Parameters of input function are processed for a specific block type. Then + * the function is wrapped into two separate loops: one that process all positions, and one that processes + * only the selected positions in the mask. Finally, these loops are wrapped in an if statement to + * select the correct loop based on the @see {@link AggregationMask} state. + */ + private static MethodHandle buildCoreLoop(MethodHandle function, int stateCount, List blockTypes) + { + // for each parameter, process for the specified block type + for (int parameterIndex = blockTypes.size() - 1; parameterIndex >= 0; parameterIndex--) { + int parameterStart = (parameterIndex * 2) + 1 + stateCount; + function = switch (blockTypes.get(parameterIndex)) { + case RLE -> processRleParameter(parameterStart, function); + case DICTIONARY -> processDictionaryParameter(parameterStart, function); + case BASIC -> processBasicParameter(parameterStart, function); + }; + } + + MethodHandle selectAllLoop = selectAllLoop(function); + MethodHandle maskedLoop = maskedLoop(function); + + return guardWithTest(MASK_IS_SELECT_ALL, selectAllLoop, maskedLoop); + } + + /** + * Process the function on every position. + */ + private static MethodHandle selectAllLoop(MethodHandle function) + { + // add unused iterationCount argument in the second position, which required by countedLoop. + function = dropArguments(function, 1, AggregationMask.class); + + return MethodHandles.countedLoop(MASK_GET_POSITION_COUNT, null, function); + } + + /** + * Process the function on only the selected position. + */ + private static MethodHandle maskedLoop(MethodHandle function) + { + // There are 4 loop clauses + // 1. selectedPositionCount: constant int -- exits the loop when index is >= to selected position + // 2. selectedPositions: constant int[] + // 3. ---- execute input function + // 4. index: an int incremented on each step + // Each loop clause, can define a variable before the loop, and then within the loop, each clause can + // update the variable, and then check if the loop should be updated. Each clause is executed in order + // in the body. This is why index is tested in the first clause, and updated in the last clause. + // + // All update and test method, must start with the parameters: int selectedPositionCount, int[] selectedPositions, int index + + MethodHandle selectedPositionStopCondition = dropArguments(INT_GREATER_THAN, 1, int[].class); + + MethodHandle arrayElementGetter = MethodHandles.arrayElementGetter(int[].class); + function = collectArguments(function, 0, arrayElementGetter); + function = dropArguments(function, 2, AggregationMask.class); + function = dropArguments(function, 0, int.class); + + MethodHandle increment = dropArguments(INT_INCREMENT, 0, int.class, int[].class); + + MethodHandle loop = MethodHandles.loop( + new MethodHandle[]{MASK_GET_SELECTED_POSITION_COUNT, null, selectedPositionStopCondition}, + new MethodHandle[]{MASK_GET_SELECTED_POSITIONS}, + new MethodHandle[]{null, function}, + new MethodHandle[]{null, increment}); + + return loop; + } + + private static MethodHandle preprocessRleParameter(MethodHandle rleMethodHandle, int parameterIndex) + { + // read the rleValue from a RunLengthEncodedBlock + rleMethodHandle = collectArguments( + rleMethodHandle, + parameterIndex, + RLE_GET_VALUE); + rleMethodHandle = explicitCastArguments(rleMethodHandle, rleMethodHandle.type().changeParameterType(parameterIndex, Block.class)); + return rleMethodHandle; + } + + /** + * For inner loop of RLE parameter, hard code the position to 0. + */ + private static MethodHandle processRleParameter(int parameterStart, MethodHandle function) + { + return MethodHandles.insertArguments(function, parameterStart + 1, 0); + } + + /** + * Outside of loop for RLE parameter, replace the RLE block with the RLE value + */ + private static MethodHandle preprocessDictionaryParameter(MethodHandle dictionaryMethodHandle, int dictionaryParameterIndex) + { + // starting method handle type + // (int position, otherParams..., Block dictionary, int[] rawIds, int rawIdsOffset + + // read the dictionary, rawIds, and rawIdsOffset from a DictionaryBlock + // (int position, otherParams..., DictionaryBlock block, DictionaryBlock block, DictionaryBlock block + dictionaryMethodHandle = collectArguments(dictionaryMethodHandle, dictionaryParameterIndex, DICTIONARY_GET_DICTIONARY); + dictionaryMethodHandle = collectArguments(dictionaryMethodHandle, dictionaryParameterIndex + 1, DICTIONARY_GET_RAW_IDS); + dictionaryMethodHandle = collectArguments(dictionaryMethodHandle, dictionaryParameterIndex + 2, DICTIONARY_GET_RAW_IDS_OFFSET); + + // consolidate the 3 dictionary block parameters into one + // (int position, otherParams..., DictionaryBlock block) + int[] reorder = IntStream.range(0, dictionaryMethodHandle.type().parameterCount()) + .map(i -> i < dictionaryParameterIndex + 2 ? i : i - 2) + .toArray(); + reorder[dictionaryParameterIndex + 1] = dictionaryParameterIndex; + reorder[dictionaryParameterIndex + 2] = dictionaryParameterIndex; + MethodType newType = dictionaryMethodHandle.type().dropParameterTypes(dictionaryParameterIndex, dictionaryParameterIndex + 2); + dictionaryMethodHandle = permuteArguments(dictionaryMethodHandle, newType, reorder); + dictionaryMethodHandle = explicitCastArguments(dictionaryMethodHandle, dictionaryMethodHandle.type().changeParameterType(dictionaryParameterIndex, Block.class)); + return dictionaryMethodHandle; + } + + /** + * For inner loop of dictionary parameter, replace positions with {@code rawIds[position + rawIdsOffset]} + */ + private static MethodHandle processDictionaryParameter(int parameterStart, MethodHandle function) + { + function = collectArguments(function, parameterStart + 1, MethodHandles.arrayElementGetter(int[].class)); + function = collectArguments(function, parameterStart + 2, INT_ADD_FUNCTION); + MethodHandle methodHandle = mergeParameterWithPositionArgument(parameterStart + 2, function); + return methodHandle; + } + + /** + * Outside of loop for RLE parameter, replace the Dictionary block with three parameters: the dictionary, rawIds array, and rawIdsOffset. + */ + private static MethodHandle processBasicParameter(int parameterStart, MethodHandle function) + { + return mergeParameterWithPositionArgument(parameterStart + 1, function); + } + + /** + * Merges the specified parameter with the loop position argument, which is always the first parameter For example, + *
{@code
+     * (int position, Block a, int aPosition, Block b, int bPositions)
+     * }
+ * with {@code parameterToMerge} set to 2, returns: + *
{@code
+     * (int position, Block a, Block b, int bPositions)
+     * }
+ */ + private static MethodHandle mergeParameterWithPositionArgument(int parameterToMerge, MethodHandle methodHandle) + { + int[] reorder = IntStream.range(0, methodHandle.type().parameterCount()) + .map(i -> i < parameterToMerge + 1 ? i : i - 1) + .toArray(); + reorder[parameterToMerge] = 0; + MethodType newType = methodHandle.type().dropParameterTypes(parameterToMerge, parameterToMerge + 1); + return permuteArguments(methodHandle, newType, reorder); + } + + /** + * Returns a method handle that test if the specified position of the methodType is of the specified block type. + */ + private static MethodHandle blockParameterIndexOf(MethodType methodType, int position, Class blockType) + { + verify(methodType.parameterType(position).equals(Block.class)); + MethodHandle instanceOf = IS_INSTANCE.bindTo(blockType); + instanceOf = dropArguments(instanceOf, 0, methodType.parameterList().subList(0, position)); + instanceOf = dropArguments(instanceOf, position + 1, methodType.parameterList().subList(position + 1, methodType.parameterCount())); + instanceOf = explicitCastArguments(instanceOf, methodType.changeReturnType(boolean.class)); + return instanceOf; + } + + // helper methods used for method handle combinations + private static int add(int left, int right) + { + return left + right; + } + + private static int increment(int index) + { + return index + 1; + } + + private static boolean greaterThan(int selectedPositionCount, int index) + { + return selectedPositionCount > index; + } + + private enum BlockType + { + RLE, DICTIONARY, BASIC + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMask.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMask.java new file mode 100644 index 000000000000..8d0a8bda0587 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMask.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; + +import javax.annotation.Nullable; + +import java.util.Arrays; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public final class AggregationMask +{ + private static final int[] NO_SELECTED_POSITIONS = new int[0]; + + private int positionCount; + private int[] selectedPositions = NO_SELECTED_POSITIONS; + private int selectedPositionCount; + + public static AggregationMask createSelectNone(int positionCount) + { + return createSelectedPositions(positionCount, NO_SELECTED_POSITIONS, 0); + } + + public static AggregationMask createSelectAll(int positionCount) + { + return new AggregationMask(positionCount); + } + + public static AggregationMask createSelectedPositions(int positionCount, int[] selectedPositions, int selectedPositionCount) + { + return new AggregationMask(positionCount, selectedPositions, selectedPositionCount); + } + + private AggregationMask(int positionCount) + { + reset(positionCount); + } + + private AggregationMask(int positionCount, int[] selectedPositions, int selectedPositionCount) + { + checkArgument(positionCount >= 0, "positionCount is negative"); + checkArgument(selectedPositionCount >= 0, "selectedPositionCount is negative"); + checkArgument(selectedPositionCount <= positionCount, "selectedPositionCount cannot be greater than positionCount"); + requireNonNull(selectedPositions, "selectedPositions is null"); + checkArgument(selectedPositions.length >= selectedPositionCount, "selectedPosition is smaller than selectedPositionCount"); + + reset(positionCount); + this.selectedPositions = selectedPositions; + this.selectedPositionCount = selectedPositionCount; + } + + public void reset(int positionCount) + { + checkArgument(positionCount >= 0, "positionCount is negative"); + this.positionCount = positionCount; + this.selectedPositionCount = positionCount; + } + + public int getPositionCount() + { + return positionCount; + } + + public boolean isSelectAll() + { + return positionCount == selectedPositionCount; + } + + public boolean isSelectNone() + { + return selectedPositionCount == 0; + } + + public Page filterPage(Page page) + { + if (isSelectAll()) { + return page; + } + if (isSelectNone()) { + return page.getRegion(0, 0); + } + return page.getPositions(Arrays.copyOf(selectedPositions, selectedPositionCount), 0, selectedPositionCount); + } + + /** + * Do not use this to filter a page, as the underlying array can change, and this will change the page after the filtering. + */ + public int getSelectedPositionCount() + { + return selectedPositionCount; + } + + public int[] getSelectedPositions() + { + checkState(!isSelectAll(), "getSelectedPositions not available when in selectAll mode"); + return selectedPositions; + } + + public void unselectNullPositions(Block block) + { + unselectPositions(block, false); + } + + public void applyMaskBlock(@Nullable Block maskBlock) + { + if (maskBlock != null) { + unselectPositions(maskBlock, true); + } + } + + private void unselectPositions(Block block, boolean shouldTestValues) + { + int positionCount = block.getPositionCount(); + checkArgument(positionCount == this.positionCount, "Block position count does not match current position count"); + if (isSelectNone()) { + return; + } + + // short circuit if there are no nulls, and we are not testing the value + if (!block.mayHaveNull() && !shouldTestValues) { + // all positions selected, so change nothing + return; + } + + if (block instanceof RunLengthEncodedBlock) { + if (test(block, 0, shouldTestValues)) { + // all positions selected, so change nothing + return; + } + // no positions selected + selectedPositionCount = 0; + return; + } + + if (positionCount == selectedPositionCount) { + if (selectedPositions.length < positionCount) { + selectedPositions = new int[positionCount]; + } + + // add all positions that pass the test + int selectedPositionsIndex = 0; + for (int position = 0; position < positionCount; position++) { + if (test(block, position, shouldTestValues)) { + selectedPositions[selectedPositionsIndex] = position; + selectedPositionsIndex++; + } + } + selectedPositionCount = selectedPositionsIndex; + return; + } + + // keep only the positions that pass the test + int originalIndex = 0; + int newIndex = 0; + for (; originalIndex < selectedPositionCount; originalIndex++) { + int position = selectedPositions[originalIndex]; + if (test(block, position, shouldTestValues)) { + selectedPositions[newIndex] = position; + newIndex++; + } + } + selectedPositionCount = newIndex; + } + + private static boolean test(Block block, int position, boolean testValue) + { + if (block.isNull(position)) { + return false; + } + if (testValue && block.getByte(position, 0) == 0) { + return false; + } + return true; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskBuilder.java new file mode 100644 index 000000000000..e4640029b991 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskBuilder.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; + +import java.util.Optional; + +public interface AggregationMaskBuilder +{ + /** + * Create an AggregationMask that only selects positions that pass the specified + * mask block, and do not have null for non-null arguments. The returned mask + * can be further modified if desired, but it should not be used after the next + * call to this method. Internally implementations are allowed to reuse position + * arrays across multiple calls. + */ + AggregationMask buildAggregationMask(Page arguments, Optional optionalMaskBlock); +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskCompiler.java new file mode 100644 index 000000000000..bc7a225960d0 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskCompiler.java @@ -0,0 +1,212 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableMap; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.FieldDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.ForLoop; +import io.airlift.bytecode.control.IfStatement; +import io.airlift.bytecode.expression.BytecodeExpression; +import io.trino.annotation.UsedByGeneratedCode; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthBlockEncoding; +import io.trino.spi.block.RunLengthEncodedBlock; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PRIVATE; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.and; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull; +import static io.airlift.bytecode.expression.BytecodeExpressions.equal; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static io.airlift.bytecode.expression.BytecodeExpressions.isNotNull; +import static io.airlift.bytecode.expression.BytecodeExpressions.isNull; +import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan; +import static io.airlift.bytecode.expression.BytecodeExpressions.newArray; +import static io.airlift.bytecode.expression.BytecodeExpressions.not; +import static io.airlift.bytecode.expression.BytecodeExpressions.notEqual; +import static io.airlift.bytecode.expression.BytecodeExpressions.or; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; + +public final class AggregationMaskCompiler +{ + private AggregationMaskCompiler() {} + + public static Constructor generateAggregationMaskBuilder(int... nonNullArgumentChannels) + { + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName(AggregationMaskBuilder.class.getSimpleName()), + type(Object.class), + type(AggregationMaskBuilder.class)); + + FieldDefinition selectedPositionsField = definition.declareField(a(PRIVATE), "selectedPositions", int[].class); + + MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); + constructor.getBody().comment("super();") + .append(constructor.getThis()) + .invokeConstructor(Object.class) + .append(constructor.getThis().setField(selectedPositionsField, newArray(type(int[].class), 0))) + .ret(); + + Parameter argumentsParameter = arg("arguments", type(Page.class)); + Parameter maskBlockParameter = arg("optionalMaskBlock", type(Optional.class, Block.class)); + MethodDefinition method = definition.declareMethod( + a(PUBLIC), + "buildAggregationMask", + type(AggregationMask.class), + argumentsParameter, + maskBlockParameter); + + BytecodeBlock body = method.getBody(); + Scope scope = method.getScope(); + Variable positionCount = scope.declareVariable("positionCount", body, argumentsParameter.invoke("getPositionCount", int.class)); + + // if page is empty, return select none + body.append(new IfStatement() + .condition(equal(positionCount, constantInt(0))) + .ifTrue(invokeStatic(AggregationMask.class, "createSelectNone", AggregationMask.class, positionCount).ret())); + + Variable maskBlock = scope.declareVariable("maskBlock", body, maskBlockParameter.invoke("orElse", Object.class, constantNull(Object.class)).cast(Block.class)); + Variable hasMaskBlock = scope.declareVariable("hasMaskBlock", body, isNotNull(maskBlock)); + Variable maskBlockMayHaveNull = scope.declareVariable( + "maskBlockMayHaveNull", + body, + and(hasMaskBlock, maskBlock.invoke("mayHaveNull", boolean.class))); + + // if mask is RLE it will be, either all allowed, or all denied + body.append(new IfStatement() + .condition(maskBlock.instanceOf(RunLengthBlockEncoding.class)) + .ifTrue(new BytecodeBlock() + .append(new IfStatement() + .condition(testMaskBlock( + maskBlock.cast(RunLengthEncodedBlock.class).invoke("getValue", Block.class), + maskBlockMayHaveNull, + constantInt(0))) + .ifTrue(invokeStatic(AggregationMask.class, "createSelectNone", AggregationMask.class, positionCount).ret())) + .append(hasMaskBlock.set(constantFalse())) + .append(maskBlockMayHaveNull.set(constantFalse())))); + + List nonNullArgs = new ArrayList<>(nonNullArgumentChannels.length); + List nonNullArgMayHaveNulls = new ArrayList<>(nonNullArgumentChannels.length); + for (int channel : nonNullArgumentChannels) { + Variable arg = scope.declareVariable("arg" + channel, body, argumentsParameter.invoke("getBlock", Block.class, constantInt(channel))); + body.append(new IfStatement() + .condition(invokeStatic(AggregationMaskCompiler.class, "isAlwaysNull", boolean.class, arg)) + .ifTrue(invokeStatic(AggregationMask.class, "createSelectNone", AggregationMask.class, positionCount).ret())); + Variable mayHaveNull = scope.declareVariable("arg" + channel + "MayHaveNull", body, arg.invoke("mayHaveNull", boolean.class)); + nonNullArgs.add(arg); + nonNullArgMayHaveNulls.add(mayHaveNull); + } + + // if there is no mask block, and all non-null arguments do not have nulls, return selectAll + BytecodeExpression isSelectAll = not(hasMaskBlock); + for (Variable mayHaveNull : nonNullArgMayHaveNulls) { + isSelectAll = and(isSelectAll, not(mayHaveNull)); + } + body.append(new IfStatement() + .condition(isSelectAll) + .ifTrue(invokeStatic(AggregationMask.class, "createSelectAll", AggregationMask.class, positionCount).ret())); + + // grow the selection array if necessary + Variable selectedPositions = scope.declareVariable("selectedPositions", body, method.getThis().getField(selectedPositionsField)); + body.append(new IfStatement() + .condition(lessThan(selectedPositions.length(), positionCount)) + .ifTrue(new BytecodeBlock() + .append(selectedPositions.set(newArray(type(int[].class), positionCount))) + .append(method.getThis().setField(selectedPositionsField, selectedPositions)))); + + // add all positions that pass the tests + Variable position = scope.declareVariable("position", body, constantInt(0)); + BytecodeExpression isPositionSelected = testMaskBlock(maskBlock, maskBlockMayHaveNull, position); + for (int i = 0; i < nonNullArgs.size(); i++) { + Variable arg = nonNullArgs.get(i); + Variable mayHaveNull = nonNullArgMayHaveNulls.get(i); + isPositionSelected = and(isPositionSelected, testPositionIsNotNull(arg, mayHaveNull, position)); + } + + Variable selectedPositionsIndex = scope.declareVariable("selectedPositionsIndex", body, constantInt(0)); + body.append(new ForLoop() + .condition(lessThan(position, positionCount)) + .update(position.increment()) + .body(new IfStatement() + .condition(isPositionSelected) + .ifTrue(new BytecodeBlock() + .append(selectedPositions.setElement(selectedPositionsIndex, position)) + .append(selectedPositionsIndex.increment())))); + + body.append(invokeStatic( + AggregationMask.class, + "createSelectedPositions", + AggregationMask.class, + positionCount, + selectedPositions, + selectedPositionsIndex) + .ret()); + + Class builderClass = defineClass( + definition, + AggregationMaskBuilder.class, + ImmutableMap.of(), + AggregationMaskCompiler.class.getClassLoader()); + + try { + return builderClass.getConstructor(); + } + catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + private static BytecodeExpression testPositionIsNotNull(BytecodeExpression block, BytecodeExpression mayHaveNulls, BytecodeExpression position) + { + return or(not(mayHaveNulls), not(block.invoke("isNull", boolean.class, position))); + } + + private static BytecodeExpression testMaskBlock(BytecodeExpression block, BytecodeExpression mayHaveNulls, BytecodeExpression position) + { + return or( + isNull(block), + and( + testPositionIsNotNull(block, mayHaveNulls, position), + notEqual(block.invoke("getByte", byte.class, position, constantInt(0)).cast(int.class), constantInt(0)))); + } + + @UsedByGeneratedCode + public static boolean isAlwaysNull(Block block) + { + if (block instanceof RunLengthEncodedBlock rle) { + return rle.getValue().isNull(0); + } + return false; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java index 45543c607509..b6b49fb294b2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java @@ -35,8 +35,16 @@ public class Aggregator private final Type finalType; private final int[] inputChannels; private final OptionalInt maskChannel; + private final AggregationMaskBuilder maskBuilder; - public Aggregator(Accumulator accumulator, Step step, Type intermediateType, Type finalType, List inputChannels, OptionalInt maskChannel) + public Aggregator( + Accumulator accumulator, + Step step, + Type intermediateType, + Type finalType, + List inputChannels, + OptionalInt maskChannel, + AggregationMaskBuilder maskBuilder) { this.accumulator = requireNonNull(accumulator, "accumulator is null"); this.step = requireNonNull(step, "step is null"); @@ -44,6 +52,7 @@ public Aggregator(Accumulator accumulator, Step step, Type intermediateType, Typ this.finalType = requireNonNull(finalType, "finalType is null"); this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null")); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); + this.maskBuilder = requireNonNull(maskBuilder, "maskBuilder is null"); checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); } @@ -58,21 +67,23 @@ public Type getType() public void processPage(Page page) { if (step.isInputRaw()) { - accumulator.addInput(page.getColumns(inputChannels), getMaskBlock(page)); + Page arguments = page.getColumns(inputChannels); + Optional maskBlock = Optional.empty(); + if (maskChannel.isPresent()) { + maskBlock = Optional.of(page.getBlock(maskChannel.getAsInt())); + } + AggregationMask mask = maskBuilder.buildAggregationMask(arguments, maskBlock); + + if (mask.isSelectNone()) { + return; + } + accumulator.addInput(arguments, mask); } else { accumulator.addIntermediate(page.getBlock(inputChannels[0])); } } - private Optional getMaskBlock(Page page) - { - if (maskChannel.isEmpty()) { - return Optional.empty(); - } - return Optional.of(page.getBlock(maskChannel.getAsInt())); - } - public void evaluate(BlockBuilder blockBuilder) { if (step.isOutputPartial()) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java index 968162f39347..057faab35c05 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java @@ -66,7 +66,7 @@ public Aggregator createAggregator() else { accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders); } - return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel); + return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder()); } public GroupedAggregator createGroupedAggregator() @@ -78,7 +78,7 @@ public GroupedAggregator createGroupedAggregator() else { accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } - return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel); + return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder()); } public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel) @@ -90,7 +90,7 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan else { accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } - return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel); + return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel, accumulatorFactory.createAggregationMaskBuilder()); } public boolean isSpillable() diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java index 71f32311d7c3..8200a225228f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java @@ -27,15 +27,18 @@ public class CompiledAccumulatorFactory private final Constructor accumulatorConstructor; private final Constructor groupedAccumulatorConstructor; private final List> lambdaInterfaces; + private final Constructor maskBuilderConstructor; public CompiledAccumulatorFactory( Constructor accumulatorConstructor, Constructor groupedAccumulatorConstructor, - List> lambdaInterfaces) + List> lambdaInterfaces, + Constructor maskBuilderConstructor) { this.accumulatorConstructor = requireNonNull(accumulatorConstructor, "accumulatorConstructor is null"); this.groupedAccumulatorConstructor = requireNonNull(groupedAccumulatorConstructor, "groupedAccumulatorConstructor is null"); this.lambdaInterfaces = ImmutableList.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); + this.maskBuilderConstructor = requireNonNull(maskBuilderConstructor, "maskBuilderConstructor is null"); } @Override @@ -87,4 +90,15 @@ public GroupedAccumulator createGroupedIntermediateAccumulator(List mask) + public void addInput(Page arguments, AggregationMask mask) { // 1. filter out positions based on mask, if present - Page filtered = mask - .map(maskBlock -> filter(arguments, maskBlock)) - .orElse(arguments); - - if (filtered.getPositionCount() == 0) { - return; - } + Page filtered = mask.filterPage(arguments); - // 2. compute a distinct mask + // 2. compute a distinct mask block Work work = hash.markDistinctRows(filtered); checkState(work.process()); Block distinctMask = work.getResult(); - // 3. feed a Page with a new mask to the underlying aggregation - accumulator.addInput(filtered, Optional.of(distinctMask)); + // 3. update original mask to the new distinct mask block + mask.reset(filtered.getPositionCount()); + mask.applyMaskBlock(distinctMask); + if (mask.isSelectNone()) { + return; + } + + // 4. feed a Page with a new mask to the underlying aggregation + accumulator.addInput(filtered, mask); } @Override @@ -210,21 +215,32 @@ public long getEstimatedSize() } @Override - public void addInput(GroupByIdBlock groupIdsBlock, Page page, Optional mask) + public void setGroupCount(long groupCount) + { + accumulator.setGroupCount(groupCount); + } + + @Override + public void addInput(GroupByIdBlock groupIdsBlock, Page page, AggregationMask mask) { Page withGroup = page.prependColumn(groupIdsBlock); // 1. filter out positions based on mask, if present - Page filteredWithGroup = mask - .map(maskBlock -> filter(withGroup, maskBlock)) - .orElse(withGroup); + Page filteredWithGroup = mask.filterPage(withGroup); // 2. compute a mask for the distinct rows (including the group id) Work work = hash.markDistinctRows(filteredWithGroup); checkState(work.process()); Block distinctMask = work.getResult(); - // 3. feed a Page with a new mask to the underlying aggregation + // 3. update original mask to the new distinct mask block + mask.reset(filteredWithGroup.getPositionCount()); + mask.applyMaskBlock(distinctMask); + if (mask.isSelectNone()) { + return; + } + + // 4. feed a Page with a new mask to the underlying aggregation GroupByIdBlock groupIds = new GroupByIdBlock(groupIdsBlock.getGroupCount(), filteredWithGroup.getBlock(0)); // drop the group id column and prepend the distinct mask column @@ -234,7 +250,7 @@ public void addInput(GroupByIdBlock groupIdsBlock, Page page, Optional ma } Page filtered = filteredWithGroup.getColumns(columnIndexes); // NOTE: the accumulator must be called even if the filtered page is empty to inform the accumulator about the group count - accumulator.addInput(groupIds, filtered, Optional.of(distinctMask)); + accumulator.addInput(groupIds, filtered, mask); } @Override @@ -258,30 +274,4 @@ public void evaluateFinal(int groupId, BlockBuilder output) @Override public void prepareFinal() {} } - - private static Page filter(Page page, Block mask) - { - int positions = mask.getPositionCount(); - if (positions > 0 && mask instanceof RunLengthEncodedBlock) { - // must have at least 1 position to be able to check the value at position 0 - if (!mask.isNull(0) && BOOLEAN.getBoolean(mask, 0)) { - return page; - } - return page.getPositions(new int[0], 0, 0); - } - boolean mayHaveNull = mask.mayHaveNull(); - int[] ids = new int[positions]; - int next = 0; - for (int i = 0; i < ids.length; ++i) { - boolean isNull = mayHaveNull && mask.isNull(i); - if (!isNull && BOOLEAN.getBoolean(mask, i)) { - ids[next++] = i; - } - } - - if (next == ids.length) { - return page; // no rows were eliminated by the filter - } - return page.getPositions(ids, 0, next); - } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAccumulator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAccumulator.java index c1706beaeeca..6ce19257bafd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAccumulator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAccumulator.java @@ -18,13 +18,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import java.util.Optional; - public interface GroupedAccumulator { long getEstimatedSize(); - void addInput(GroupByIdBlock groupIdsBlock, Page page, Optional mask); + void setGroupCount(long groupCount); + + void addInput(GroupByIdBlock groupIdsBlock, Page page, AggregationMask mask); void addIntermediate(GroupByIdBlock groupIdsBlock, Block block); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java index 9098f325d145..e0700e6b68df 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java @@ -37,8 +37,16 @@ public class GroupedAggregator private final Type finalType; private final int[] inputChannels; private final OptionalInt maskChannel; + private final AggregationMaskBuilder maskBuilder; - public GroupedAggregator(GroupedAccumulator accumulator, Step step, Type intermediateType, Type finalType, List inputChannels, OptionalInt maskChannel) + public GroupedAggregator( + GroupedAccumulator accumulator, + Step step, + Type intermediateType, + Type finalType, + List inputChannels, + OptionalInt maskChannel, + AggregationMaskBuilder maskBuilder) { this.accumulator = requireNonNull(accumulator, "accumulator is null"); this.step = requireNonNull(step, "step is null"); @@ -46,6 +54,7 @@ public GroupedAggregator(GroupedAccumulator accumulator, Step step, Type interme this.finalType = requireNonNull(finalType, "finalType is null"); this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null")); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); + this.maskBuilder = requireNonNull(maskBuilder, "maskBuilder is null"); checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); } @@ -64,22 +73,26 @@ public Type getType() public void processPage(GroupByIdBlock groupIds, Page page) { + accumulator.setGroupCount(groupIds.getGroupCount()); + if (step.isInputRaw()) { - accumulator.addInput(groupIds, page.getColumns(inputChannels), getMaskBlock(page)); + Page arguments = page.getColumns(inputChannels); + Optional maskBlock = Optional.empty(); + if (maskChannel.isPresent()) { + maskBlock = Optional.of(page.getBlock(maskChannel.getAsInt())); + } + AggregationMask mask = maskBuilder.buildAggregationMask(arguments, maskBlock); + + if (mask.isSelectNone()) { + return; + } + accumulator.addInput(groupIds, arguments, mask); } else { accumulator.addIntermediate(groupIds, page.getBlock(inputChannels[0])); } } - private Optional getMaskBlock(Page page) - { - if (maskChannel.isEmpty()) { - return Optional.empty(); - } - return Optional.of(page.getBlock(maskChannel.getAsInt())); - } - public void prepareFinal() { accumulator.prepareFinal(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java index 6c5de06b3659..a6d30889dd58 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java @@ -61,8 +61,9 @@ public static void input( @TypeParameter("V") Type valueType, @AggregationState({"K", "V"}) KeyValuePairsState state, @BlockPosition @SqlType("K") Block key, + @BlockIndex int keyPosition, @NullablePosition @BlockPosition @SqlType("V") Block value, - @BlockIndex int position) + @BlockIndex int valuePosition) { KeyValuePairs pairs = state.get(); if (pairs == null) { @@ -71,7 +72,7 @@ public static void input( } long startSize = pairs.estimatedInMemorySize(); - pairs.add(key, value, position, position); + pairs.add(key, value, keyPosition, valuePosition); state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java index 0c29d7421cfb..2b94e7084882 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java @@ -54,13 +54,14 @@ public static void input( @AggregationState("K") InOut keyState, @AggregationState("V") InOut valueState, @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockIndex int valuePosition, @BlockPosition @SqlType("K") Block keyBlock, - @BlockIndex int position) + @BlockIndex int keyPosition) throws Throwable { - if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) > 0) { - keyState.set(keyBlock, position); - valueState.set(valueBlock, position); + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, keyPosition, keyState)) > 0) { + keyState.set(keyBlock, keyPosition); + valueState.set(valueBlock, valuePosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java index b26648b0fa4b..cb5d6588dd0c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java @@ -54,13 +54,14 @@ public static void input( @AggregationState("K") InOut keyState, @AggregationState("V") InOut valueState, @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockIndex int valuePosition, @BlockPosition @SqlType("K") Block keyBlock, - @BlockIndex int position) + @BlockIndex int keyPosition) throws Throwable { - if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) < 0) { - keyState.set(keyBlock, position); - valueState.set(valueBlock, position); + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, keyPosition, keyState)) < 0) { + keyState.set(keyBlock, keyPosition); + valueState.set(valueBlock, valuePosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java index 5e752410ee06..c750b221bed8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java @@ -27,12 +27,10 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; -import java.util.Optional; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.BooleanType.BOOLEAN; import static java.lang.Long.max; import static java.util.Objects.requireNonNull; @@ -96,6 +94,12 @@ public GroupedAccumulator createGroupedIntermediateAccumulator(List mask) + public void addInput(Page page, AggregationMask mask) { - if (mask.isPresent()) { - page = filter(page, mask.orElseThrow()); - } - pagesIndex.addPage(page); + pagesIndex.addPage(mask.filterPage(page)); } @Override @@ -158,7 +159,11 @@ public void evaluateFinal(BlockBuilder blockBuilder) { pagesIndex.sort(orderByChannels, orderings); Iterator pagesIterator = pagesIndex.getSortedPages(); - pagesIterator.forEachRemaining(arguments -> accumulator.addInput(arguments.getColumns(argumentChannels), Optional.empty())); + AggregationMask mask = AggregationMask.createSelectAll(0); + pagesIterator.forEachRemaining(arguments -> { + mask.reset(arguments.getPositionCount()); + accumulator.addInput(arguments.getColumns(argumentChannels), mask); + }); accumulator.evaluateFinal(blockBuilder); } } @@ -200,27 +205,24 @@ public long getEstimatedSize() } @Override - public void addInput(GroupByIdBlock groupIdsBlock, Page page, Optional mask) + public void setGroupCount(long groupCount) { - groupCount = max(groupCount, groupIdsBlock.getGroupCount()); + this.groupCount = max(this.groupCount, groupCount); + accumulator.setGroupCount(groupCount); + } + + @Override + public void addInput(GroupByIdBlock groupIdsBlock, Page page, AggregationMask mask) + { + if (mask.isSelectNone()) { + return; + } // Add group id block page = page.appendColumn(groupIdsBlock); // mask page - if (mask.isPresent()) { - page = filter(page, mask.orElseThrow()); - } - if (page.getPositionCount() == 0) { - // page was entirely filtered out, but we need to inform the accumulator of the new group count - accumulator.addInput( - new GroupByIdBlock(groupCount, page.getBlock(page.getChannelCount() - 1)), - page.getColumns(argumentChannels), - Optional.empty()); - } - else { - pagesIndex.addPage(page); - } + pagesIndex.addPage(mask.filterPage(page)); } @Override @@ -246,23 +248,14 @@ public void prepareFinal() { pagesIndex.sort(orderByChannels, orderings); Iterator pagesIterator = pagesIndex.getSortedPages(); - pagesIterator.forEachRemaining(page -> accumulator.addInput( - new GroupByIdBlock(groupCount, page.getBlock(page.getChannelCount() - 1)), - page.getColumns(argumentChannels), - Optional.empty())); - } - } - - private static Page filter(Page page, Block mask) - { - int[] ids = new int[mask.getPositionCount()]; - int next = 0; - for (int i = 0; i < page.getPositionCount(); ++i) { - if (BOOLEAN.getBoolean(mask, i)) { - ids[next++] = i; - } + AggregationMask mask = AggregationMask.createSelectAll(0); + pagesIterator.forEachRemaining(page -> { + mask.reset(page.getPositionCount()); + accumulator.addInput( + new GroupByIdBlock(groupCount, page.getBlock(page.getChannelCount() - 1)), + page.getColumns(argumentChannels), + mask); + }); } - - return page.getPositions(ids, 0, next); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java index 8355dd9b5dda..9d704f3f5ae7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java @@ -47,11 +47,11 @@ private ListaggAggregationFunction() {} public static void input( @AggregationState ListaggAggregationState state, @BlockPosition @SqlType("VARCHAR") Block value, + @BlockIndex int position, @SqlType("VARCHAR") Slice separator, @SqlType("BOOLEAN") boolean overflowError, @SqlType("VARCHAR") Slice overflowFiller, - @SqlType("BOOLEAN") boolean showOverflowEntryCount, - @BlockIndex int position) + @SqlType("BOOLEAN") boolean showOverflowEntryCount) { if (state.isEmpty()) { if (overflowFiller.length() > MAX_OVERFLOW_FILLER_LENGTH) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java index 691414610afa..ea0081f3995d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java @@ -39,12 +39,13 @@ private MaxByNAggregationFunction() {} public static void input( @AggregationState({"K", "V"}) MaxByNState state, @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockIndex int valuePosition, @BlockPosition @SqlType("K") Block keyBlock, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockIndex int keyPosition, + @SqlType("BIGINT") long n) { state.initialize(n); - state.add(keyBlock, valueBlock, blockIndex); + state.add(keyBlock, keyPosition, valueBlock, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java index e36d733095e7..737ec8a20e73 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java @@ -39,12 +39,13 @@ private MinByNAggregationFunction() {} public static void input( @AggregationState({"K", "V"}) MinByNState state, @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockIndex int valuePosition, @BlockPosition @SqlType("K") Block keyBlock, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockIndex int keyPosition, + @SqlType("BIGINT") long n) { state.initialize(n); - state.add(keyBlock, valueBlock, blockIndex); + state.add(keyBlock, keyPosition, valueBlock, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java index d5bcd8e6c116..0699f7f75615 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java @@ -29,7 +29,7 @@ public interface MinMaxByNState /** * Adds the value to this state. */ - void add(Block keyBlock, Block valueBlock, int position); + void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition); /** * Merge with the specified state. diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java index 2157943a68ad..5d2c2e0e2c61 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java @@ -81,12 +81,12 @@ public final void initialize(long n) } @Override - public final void add(Block keyBlock, Block valueBlock, int position) + public final void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) { TypedKeyValueHeap typedHeap = getTypedKeyValueHeap(); size -= typedHeap.getEstimatedSize(); - typedHeap.add(keyBlock, valueBlock, position); + typedHeap.add(keyBlock, keyPosition, valueBlock, valuePosition); size += typedHeap.getEstimatedSize(); } @@ -210,9 +210,9 @@ public final void initialize(long n) } @Override - public final void add(Block keyBlock, Block valueBlock, int position) + public final void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) { - typedHeap.add(keyBlock, valueBlock, position); + typedHeap.add(keyBlock, keyPosition, valueBlock, valuePosition); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java index b3c5c697b152..19d1c20f0eed 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java @@ -176,23 +176,23 @@ private void remove() siftDown(); } - public void add(Block keyBlock, Block valueBlock, int position) + public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) { - checkArgument(!keyBlock.isNull(position)); + checkArgument(!keyBlock.isNull(keyPosition)); if (positionCount == capacity) { - if (keyGreaterThanOrEqual(keyBlockBuilder, heapIndex[0], keyBlock, position)) { + if (keyGreaterThanOrEqual(keyBlockBuilder, heapIndex[0], keyBlock, keyPosition)) { return; // and new element is not larger than heap top: do not add } heapIndex[0] = keyBlockBuilder.getPositionCount(); - keyType.appendTo(keyBlock, position, keyBlockBuilder); - valueType.appendTo(valueBlock, position, valueBlockBuilder); + keyType.appendTo(keyBlock, keyPosition, keyBlockBuilder); + valueType.appendTo(valueBlock, valuePosition, valueBlockBuilder); siftDown(); } else { heapIndex[positionCount] = keyBlockBuilder.getPositionCount(); positionCount++; - keyType.appendTo(keyBlock, position, keyBlockBuilder); - valueType.appendTo(valueBlock, position, valueBlockBuilder); + keyType.appendTo(keyBlock, keyPosition, keyBlockBuilder); + valueType.appendTo(valueBlock, valuePosition, valueBlockBuilder); siftUp(); } compactIfNecessary(); @@ -206,7 +206,7 @@ public void addAll(TypedKeyValueHeap otherHeap) public void addAll(Block keysBlock, Block valuesBlock) { for (int i = 0; i < keysBlock.getPositionCount(); i++) { - add(keysBlock, valuesBlock, i); + add(keysBlock, i, valuesBlock, i); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java index ef6b86d6cd89..328cc6c43352 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java @@ -37,8 +37,8 @@ private MaxNAggregationFunction() {} public static void input( @AggregationState("E") MaxNState state, @BlockPosition @SqlType("E") Block block, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockIndex int blockIndex, + @SqlType("BIGINT") long n) { state.initialize(n); state.add(block, blockIndex); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java index aae0eefced43..e19c3b9143ae 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java @@ -37,8 +37,8 @@ private MinNAggregationFunction() {} public static void input( @AggregationState("E") MinNState state, @BlockPosition @SqlType("E") Block block, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockIndex int blockIndex, + @SqlType("BIGINT") long n) { state.initialize(n); state.add(block, blockIndex); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java index 094731abdac3..8a8e64d49f4a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java @@ -33,17 +33,17 @@ public GroupedMultimapAggregationState(Type keyType, Type valueType) } @Override - public void add(Block keyBlock, Block valueBlock, int position) + public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) { prepareAdd(); - appendAtChannel(VALUE_CHANNEL, valueBlock, position); - appendAtChannel(KEY_CHANNEL, keyBlock, position); + appendAtChannel(VALUE_CHANNEL, valueBlock, keyPosition); + appendAtChannel(KEY_CHANNEL, keyBlock, valuePosition); } @Override protected boolean accept(MultimapAggregationStateConsumer consumer, PageBuilder pageBuilder, int currentPosition) { - consumer.accept(pageBuilder.getBlockBuilder(KEY_CHANNEL), pageBuilder.getBlockBuilder(VALUE_CHANNEL), currentPosition); + consumer.accept(pageBuilder.getBlockBuilder(KEY_CHANNEL), currentPosition, pageBuilder.getBlockBuilder(VALUE_CHANNEL), currentPosition); return true; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java index 986d03094142..de417b5d511d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java @@ -56,10 +56,11 @@ private MultimapAggregationFunction() {} public static void input( @AggregationState({"K", "V"}) MultimapAggregationState state, @BlockPosition @SqlType("K") Block key, + @BlockIndex int keyPosition, @NullablePosition @BlockPosition @SqlType("V") Block value, - @BlockIndex int position) + @BlockIndex int valuePosition) { - state.add(key, value, position); + state.add(key, keyPosition, value, valuePosition); } @CombineFunction @@ -97,7 +98,7 @@ public static void output( BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, state.getEntryCount(), expectedValueSize(keyType, 100)); TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctFrom, keyHashCode, state.getEntryCount(), "multimap_agg"); - state.forEach((key, value, keyValueIndex) -> { + state.forEach((key, keyValueIndex, value, valuePosition) -> { // Merge values of the same key into an array if (keySet.add(key, keyValueIndex)) { keyType.appendTo(key, keyValueIndex, distinctKeyBlockBuilder); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java index 97aeb9307962..b3706691925f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java @@ -25,7 +25,7 @@ public interface MultimapAggregationState extends AccumulatorState { - void add(Block keyBlock, Block valueBlock, int position); + void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition); void forEach(MultimapAggregationStateConsumer consumer); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateConsumer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateConsumer.java index c7ec00216307..ed76465c56ae 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateConsumer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateConsumer.java @@ -17,5 +17,5 @@ public interface MultimapAggregationStateConsumer { - void accept(Block keyBlock, Block valueBlock, int position); + void accept(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java index e3acd8549b8b..64eef68bdc08 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateSerializer.java @@ -56,10 +56,10 @@ public void serialize(MultimapAggregationState state, BlockBuilder out) return; } BlockBuilder entryBuilder = out.beginBlockEntry(); - state.forEach((keyBlock, valueBlock, position) -> { + state.forEach((keyBlock, keyPosition, valueBlock, valuePosition) -> { BlockBuilder rowBlockBuilder = entryBuilder.beginBlockEntry(); - valueType.appendTo(valueBlock, position, rowBlockBuilder); - keyType.appendTo(keyBlock, position, rowBlockBuilder); + valueType.appendTo(valueBlock, valuePosition, rowBlockBuilder); + keyType.appendTo(keyBlock, keyPosition, rowBlockBuilder); entryBuilder.closeEntry(); }); out.closeEntry(); @@ -73,7 +73,7 @@ public void deserialize(Block block, int index, MultimapAggregationState state) Block keys = columnarRow.getField(KEY_CHANNEL); Block values = columnarRow.getField(VALUE_CHANNEL); for (int i = 0; i < columnarRow.getPositionCount(); i++) { - state.add(keys, values, i); + state.add(keys, i, values, i); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java index e692f8baecad..cf926193cac8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java @@ -52,17 +52,17 @@ private SingleMultimapAggregationState(Type keyType, Type valueType, BlockBuilde } @Override - public void add(Block key, Block value, int position) + public void add(Block key, int keyPosition, Block value, int valuePosition) { - keyType.appendTo(key, position, keyBlockBuilder); - valueType.appendTo(value, position, valueBlockBuilder); + keyType.appendTo(key, keyPosition, keyBlockBuilder); + valueType.appendTo(value, valuePosition, valueBlockBuilder); } @Override public void forEach(MultimapAggregationStateConsumer consumer) { for (int i = 0; i < keyBlockBuilder.getPositionCount(); i++) { - consumer.accept(keyBlockBuilder, valueBlockBuilder, i); + consumer.accept(keyBlockBuilder, i, valueBlockBuilder, i); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java index 8b63b56712ec..8dd361393483 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java @@ -23,6 +23,7 @@ public class CompilerConfig { private int expressionCacheSize = 10_000; + private boolean specializeAggregationLoops = true; @Min(0) public int getExpressionCacheSize() @@ -37,4 +38,16 @@ public CompilerConfig setExpressionCacheSize(int expressionCacheSize) this.expressionCacheSize = expressionCacheSize; return this; } + + public boolean isSpecializeAggregationLoops() + { + return specializeAggregationLoops; + } + + @Config("compiler.specialized-aggregation-loops") + public CompilerConfig setSpecializeAggregationLoops(boolean specializeAggregationLoops) + { + this.specializeAggregationLoops = specializeAggregationLoops; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 38d28a77926e..6503a5318aee 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -408,6 +408,7 @@ public class LocalExecutionPlanner private final TableExecuteContextManager tableExecuteContextManager; private final ExchangeManagerRegistry exchangeManagerRegistry; private final PositionsAppenderFactory positionsAppenderFactory; + private final boolean specializeAggregationLoops; private final NonEvictableCache accumulatorFactoryCache = buildNonEvictableCache(CacheBuilder.newBuilder() .maximumSize(1000) @@ -441,7 +442,8 @@ public LocalExecutionPlanner( DynamicFilterConfig dynamicFilterConfig, BlockTypeOperators blockTypeOperators, TableExecuteContextManager tableExecuteContextManager, - ExchangeManagerRegistry exchangeManagerRegistry) + ExchangeManagerRegistry exchangeManagerRegistry, + CompilerConfig compilerConfig) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.metadata = plannerContext.getMetadata(); @@ -472,6 +474,7 @@ public LocalExecutionPlanner( this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.positionsAppenderFactory = new PositionsAppenderFactory(blockTypeOperators); + this.specializeAggregationLoops = compilerConfig.isSpecializeAggregationLoops(); } public LocalExecutionPlan plan( @@ -3634,7 +3637,8 @@ private AggregatorFactory buildAggregatorFactory( () -> generateAccumulatorFactory( resolvedFunction.getSignature(), aggregationImplementation, - resolvedFunction.getFunctionNullability())); + resolvedFunction.getFunctionNullability(), + specializeAggregationLoops)); if (aggregation.isDistinct()) { accumulatorFactory = new DistinctAccumulatorFactory( diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 03b22c0afe86..5b9bdba3981c 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -160,6 +160,7 @@ import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; import io.trino.sql.parser.SqlParser; +import io.trino.sql.planner.CompilerConfig; import io.trino.sql.planner.LocalExecutionPlanner; import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import io.trino.sql.planner.LogicalPlanner; @@ -952,7 +953,8 @@ private List createDrivers(Session session, Plan plan, OutputFactory out new DynamicFilterConfig(), blockTypeOperators, tableExecuteContextManager, - exchangeManagerRegistry); + exchangeManagerRegistry, + new CompilerConfig()); // plan query LocalExecutionPlan localExecutionPlan = executionPlanner.plan( diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index cc50e5fcb89a..668250b02a43 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -43,6 +43,7 @@ import io.trino.sql.gen.JoinFilterFunctionCompiler; import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; +import io.trino.sql.planner.CompilerConfig; import io.trino.sql.planner.LocalExecutionPlanner; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.Partitioning; @@ -174,7 +175,8 @@ public static LocalExecutionPlanner createTestingPlanner() new DynamicFilterConfig(), blockTypeOperators, new TableExecuteContextManager(), - new ExchangeManagerRegistry(new ExchangeHandleResolver())); + new ExchangeManagerRegistry(new ExchangeHandleResolver()), + new CompilerConfig()); } public static TaskInfo updateTask(SqlTask sqlTask, List splitAssignments, OutputBuffers outputBuffers) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/BenchmarkAggregationMaskBuilder.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/BenchmarkAggregationMaskBuilder.java new file mode 100644 index 000000000000..5afb2fbbff97 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/BenchmarkAggregationMaskBuilder.java @@ -0,0 +1,392 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ShortArrayBlock; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.Optional; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import static io.trino.jmh.Benchmarks.benchmark; +import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder; + +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.SECONDS) +@Fork(value = 1, jvmArgsAppend = { + "--add-modules=jdk.incubator.vector", + "-XX:+UnlockDiagnosticVMOptions", +// "-XX:CompileCommand=print,*BenchmarkCore*.*", +// "-XX:PrintAssemblyOptions=intel" +}) +@Warmup(iterations = 10, time = 1000, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 1000, timeUnit = TimeUnit.MILLISECONDS) +public class BenchmarkAggregationMaskBuilder +{ + private final AggregationMaskBuilder rleNoNullsBuilder = new InterpretedAggregationMaskBuilder(0, 3, 6); + private final AggregationMaskBuilder rleNoNullsBuilderCurrent = new CurrentAggregationMaskBuilder(0, 3, 6); + private final AggregationMaskBuilder rleNoNullsBuilderHandCoded = new HandCodedAggregationMaskBuilder(0, 3, 6); + private final AggregationMaskBuilder rleNoNullsBuilderCompiled = compiledMaskBuilder(0, 3, 6); + + private final AggregationMaskBuilder noNullsBuilder = new InterpretedAggregationMaskBuilder(1, 4, 7); + private final AggregationMaskBuilder noNullsBuilderCurrent = new CurrentAggregationMaskBuilder(1, 4, 7); + private final AggregationMaskBuilder noNullsBuilderHandCoded = new HandCodedAggregationMaskBuilder(1, 4, 7); + private final AggregationMaskBuilder noNullsBuilderCompiled = compiledMaskBuilder(1, 4, 7); + + private final AggregationMaskBuilder someNullsBuilder = new InterpretedAggregationMaskBuilder(2, 5, 8); + private final AggregationMaskBuilder someNullsBuilderCurrent = new CurrentAggregationMaskBuilder(2, 5, 8); + private final AggregationMaskBuilder someNullsBuilderHandCoded = new HandCodedAggregationMaskBuilder(2, 5, 8); + private final AggregationMaskBuilder someNullsBuilderCompiled = compiledMaskBuilder(2, 5, 8); + + private final AggregationMaskBuilder oneBlockSomeNullsBuilder = new InterpretedAggregationMaskBuilder(2); + private final AggregationMaskBuilder oneBlockSomeNullsBuilderCurrent = new CurrentAggregationMaskBuilder(2, -1, -1); + private final AggregationMaskBuilder oneBlockSomeNullsBuilderHandCoded = new HandCodedAggregationMaskBuilder(2, -1, -1); + private final AggregationMaskBuilder oneBlockSomeNullsBuilderCompiled = compiledMaskBuilder(2); + + private final AggregationMaskBuilder allBlocksBuilder = new InterpretedAggregationMaskBuilder(0, 1, 2, 3, 4, 5, 6, 7, 8); + private final AggregationMaskBuilder allBlocksBuilderCompiled = compiledMaskBuilder(0, 1, 2, 3, 4, 5, 6, 7, 8); + + private Page arguments; + + @Setup + public void setup() + throws Throwable + { + int positions = 10_000; + + Block shortRleNoNulls = RunLengthEncodedBlock.create(new ShortArrayBlock(1, Optional.empty(), new short[] {42}), positions); + Block shortNoNulls = new ShortArrayBlock(new long[positions].length, Optional.empty(), new short[positions]); + Block shortSomeNulls = new ShortArrayBlock(new long[positions].length, someNulls(positions, 0.3), new short[positions]); + + Block intRleNoNulls = RunLengthEncodedBlock.create(new IntArrayBlock(1, Optional.empty(), new int[] {42}), positions); + Block intNoNulls = new IntArrayBlock(new long[positions].length, Optional.empty(), new int[positions]); + Block intSomeNulls = new IntArrayBlock(new long[positions].length, someNulls(positions, 0.3), new int[positions]); + + Block longRleNoNulls = RunLengthEncodedBlock.create(new LongArrayBlock(1, Optional.empty(), new long[] {42}), positions); + Block longNoNulls = new LongArrayBlock(new long[positions].length, Optional.empty(), new long[positions]); + Block longSomeNulls = new LongArrayBlock(new long[positions].length, someNulls(positions, 0.3), new long[positions]); + + Block rleAllNulls = RunLengthEncodedBlock.create(new ShortArrayBlock(1, Optional.of(new boolean[] {true}), new short[] {42}), positions); + + arguments = new Page( + shortRleNoNulls, + shortNoNulls, + shortSomeNulls, + intRleNoNulls, + intNoNulls, + intSomeNulls, + longRleNoNulls, + longNoNulls, + longSomeNulls, + rleAllNulls); + } + + private static Optional someNulls(int positions, double nullRatio) + { + boolean[] nulls = new boolean[positions]; + for (int i = 0; i < nulls.length; i++) { + // 0.7 ^ 3 = 0.343 + nulls[i] = ThreadLocalRandom.current().nextDouble() < nullRatio; + } + return Optional.of(nulls); + } + + @Benchmark + public Object rleNoNullsBlocksInterpreted() + { + return rleNoNullsBuilder.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object rleNoNullsBlocksCurrent() + { + return rleNoNullsBuilderCurrent.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object rleNoNullsBlocksHandCoded() + { + return rleNoNullsBuilderHandCoded.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object rleNoNullsBlocksCompiled() + { + return rleNoNullsBuilderCompiled.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object noNullsBlocksInterpreted() + { + return noNullsBuilder.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object noNullsBlocksCurrent() + { + return noNullsBuilderCurrent.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object noNullsBlocksHandCoded() + { + return noNullsBuilderHandCoded.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object noNullsBlocksCompiled() + { + return noNullsBuilderCompiled.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object someNullsBlocksInterpreted() + { + return someNullsBuilder.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object someNullsBlocksCurrent() + { + return someNullsBuilderCurrent.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object someNullsBlocksHandCoded() + { + return someNullsBuilderHandCoded.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object someNullsBlocksCompiled() + { + return someNullsBuilderCompiled.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object oneBlockSomeNullsInterpreted() + { + return oneBlockSomeNullsBuilder.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object oneBlockSomeNullsCurrent() + { + return oneBlockSomeNullsBuilderCurrent.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object oneBlockSomeNullsHandCoded() + { + return oneBlockSomeNullsBuilderHandCoded.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object oneBlockSomeNullsCompiled() + { + return oneBlockSomeNullsBuilderCompiled.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object allBlocksInterpreted() + { + return allBlocksBuilder.buildAggregationMask(arguments, Optional.empty()); + } + + @Benchmark + public Object allBlocksCompiled() + { + return allBlocksBuilderCompiled.buildAggregationMask(arguments, Optional.empty()); + } + + public static void main(String[] args) + throws Throwable + { + BenchmarkAggregationMaskBuilder bench = new BenchmarkAggregationMaskBuilder(); + bench.setup(); + bench.rleNoNullsBlocksInterpreted(); + bench.noNullsBlocksInterpreted(); + bench.someNullsBlocksInterpreted(); + bench.allBlocksInterpreted(); + bench.someNullsBlocksCurrent(); + bench.someNullsBlocksHandCoded(); + bench.someNullsBlocksCompiled(); + + benchmark(BenchmarkAggregationMaskBuilder.class).run(); + } + + private static class CurrentAggregationMaskBuilder + implements AggregationMaskBuilder + { + private final int first; + private final int second; + private final int third; + + private final AggregationMask mask = AggregationMask.createSelectAll(0); + + public CurrentAggregationMaskBuilder(int first, int second, int third) + { + this.first = first; + this.second = second; + this.third = third; + } + + @Override + public AggregationMask buildAggregationMask(Page arguments, Optional optionalMaskBlock) + { + int positionCount = arguments.getPositionCount(); + mask.reset(positionCount); + mask.applyMaskBlock(optionalMaskBlock.orElse(null)); + if (first >= 0) { + mask.unselectNullPositions(arguments.getBlock(first)); + } + if (second >= 0) { + mask.unselectNullPositions(arguments.getBlock(second)); + } + if (third >= 0) { + mask.unselectNullPositions(arguments.getBlock(third)); + } + return mask; + } + } + + private static class HandCodedAggregationMaskBuilder + implements AggregationMaskBuilder + { + private final int first; + private final int second; + private final int third; + + public HandCodedAggregationMaskBuilder(int first, int second, int third) + { + this.first = first; + this.second = second; + this.third = third; + } + + private int[] selectedPositions = new int[0]; + + @Override + public AggregationMask buildAggregationMask(Page arguments, Optional optionalMaskBlock) + { + int positionCount = arguments.getPositionCount(); + + // if page is empty, we are done + if (positionCount == 0) { + return AggregationMask.createSelectNone(positionCount); + } + + Block maskBlock = optionalMaskBlock.orElse(null); + boolean hasMaskBlock = maskBlock != null; + boolean maskBlockMayHaveNull = hasMaskBlock && maskBlock.mayHaveNull(); + if (maskBlock instanceof RunLengthEncodedBlock rle) { + Block value = rle.getValue(); + if (!(value == null || + ((!maskBlockMayHaveNull || !value.isNull(0)) && + value.getByte(0, 0) != 0))) { + return AggregationMask.createSelectNone(positionCount); + } + // mask block is always true, so do not evaluate mask block + hasMaskBlock = false; + maskBlockMayHaveNull = false; + } + + Block nonNullArg0 = first < 0 ? null : arguments.getBlock(first); + if (isAlwaysNull(nonNullArg0)) { + return AggregationMask.createSelectNone(positionCount); + } + boolean nonNullArg0MayHaveNull = nonNullArg0 != null && nonNullArg0.mayHaveNull(); + + Block nonNullArg1 = third < 0 ? null : arguments.getBlock(second); + if (isAlwaysNull(nonNullArg1)) { + return AggregationMask.createSelectNone(positionCount); + } + boolean nonNullArg1MayHaveNull = nonNullArg1 != null && nonNullArg1.mayHaveNull(); + + Block nonNullArgN = third < 0 ? null : arguments.getBlock(third); + if (isAlwaysNull(nonNullArgN)) { + return AggregationMask.createSelectNone(positionCount); + } + boolean nonNullArgNMayHaveNull = nonNullArgN != null && nonNullArgN.mayHaveNull(); + + // if there is no mask block, and all non-null arguments do not have nulls, we are done + if (!hasMaskBlock && !nonNullArg0MayHaveNull && !nonNullArg1MayHaveNull && !nonNullArgNMayHaveNull) { + return AggregationMask.createSelectAll(positionCount); + } + + // grow the selection array if necessary + int[] selectedPositions = this.selectedPositions; + if (selectedPositions.length < positionCount) { + selectedPositions = new int[positionCount]; + this.selectedPositions = selectedPositions; + } + + // add all positions that pass the tests + int selectedPositionsIndex = 0; + for (int position = 0; position < positionCount; position++) { + if ((maskBlock == null || ((!maskBlockMayHaveNull || !maskBlock.isNull(position)) && maskBlock.getByte(position, 0) != 0)) && + (!nonNullArg0MayHaveNull || !nonNullArg0.isNull(position)) && + (!nonNullArg1MayHaveNull || !nonNullArg1.isNull(position)) && + (!nonNullArgNMayHaveNull || !nonNullArgN.isNull(position))) { + selectedPositions[selectedPositionsIndex] = position; + selectedPositionsIndex++; + } + } + return AggregationMask.createSelectedPositions(positionCount, selectedPositions, selectedPositionsIndex); + } + } + + private static boolean isAlwaysNull(Block block) + { + if (block instanceof RunLengthEncodedBlock rle) { + return rle.getValue().isNull(0); + } + return false; + } + + private static boolean testMaskBlock(Block block, boolean mayHaveNulls, int position) + { + return block == null || + ((!mayHaveNulls || !block.isNull(position)) && + block.getByte(position, 0) != 0); + } + + private static boolean isNotNull(Block block, boolean mayHaveNulls, int position) + { + return !mayHaveNulls || !block.isNull(position); + } + + private static AggregationMaskBuilder compiledMaskBuilder(int... ints) + { + try { + return generateAggregationMaskBuilder(ints).newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/InterpretedAggregationMaskBuilder.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/InterpretedAggregationMaskBuilder.java new file mode 100644 index 000000000000..1c2f076e0a50 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/InterpretedAggregationMaskBuilder.java @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class InterpretedAggregationMaskBuilder + implements AggregationMaskBuilder +{ + private final List nonNullArguments; + private int[] selectedPositions = new int[0]; + + public InterpretedAggregationMaskBuilder(int... nonNullArguments) + { + this.nonNullArguments = Arrays.stream(nonNullArguments) + .mapToObj(NonNullArgument::new) + .collect(toImmutableList()); + } + + @Override + public AggregationMask buildAggregationMask(Page arguments, Optional optionalMaskBlock) + { + int positionCount = arguments.getPositionCount(); + + // if page is empty, we are done + if (positionCount == 0) { + return AggregationMask.createSelectNone(positionCount); + } + + Block maskBlock = optionalMaskBlock.orElse(null); + boolean hasMaskBlock = maskBlock != null; + boolean maskBlockMayHaveNull = hasMaskBlock && maskBlock.mayHaveNull(); + if (maskBlock instanceof RunLengthEncodedBlock rle) { + if (!testMaskBlock(rle.getValue(), maskBlockMayHaveNull, 0)) { + return AggregationMask.createSelectNone(positionCount); + } + // mask block is always true, so do not evaluate mask block + hasMaskBlock = false; + maskBlockMayHaveNull = false; + } + + for (NonNullArgument nonNullArgument : nonNullArguments) { + nonNullArgument.reset(arguments); + if (nonNullArgument.isAlwaysNull()) { + return AggregationMask.createSelectNone(positionCount); + } + } + + // if there is no mask block, and all non-null arguments do not have nulls, we are done + if (!hasMaskBlock && nonNullArguments.stream().noneMatch(NonNullArgument::mayHaveNull)) { + return AggregationMask.createSelectAll(positionCount); + } + + // grow the selection array if necessary + int[] selectedPositions = this.selectedPositions; + if (selectedPositions.length < positionCount) { + selectedPositions = new int[positionCount]; + this.selectedPositions = selectedPositions; + } + + // add all positions that pass the tests + int selectedPositionsIndex = 0; + for (int i = 0; i < positionCount; i++) { + int position = i; + if (testMaskBlock(maskBlock, maskBlockMayHaveNull, position) && nonNullArguments.stream().allMatch(arg -> arg.isNotNull(position))) { + selectedPositions[selectedPositionsIndex] = position; + selectedPositionsIndex++; + } + } + return AggregationMask.createSelectedPositions(positionCount, selectedPositions, selectedPositionsIndex); + } + + private static boolean testMaskBlock(Block block, boolean mayHaveNulls, int position) + { + if (block == null) { + return true; + } + if (mayHaveNulls && block.isNull(position)) { + return false; + } + return block.getByte(position, 0) != 0; + } + + private static final class NonNullArgument + { + private final int channel; + private Block block; + private boolean mayHaveNull; + + public NonNullArgument(int channel) + { + this.channel = channel; + } + + public void reset(Page arguments) + { + block = arguments.getBlock(channel); + mayHaveNull = block.mayHaveNull(); + } + + public boolean mayHaveNull() + { + return mayHaveNull; + } + + private boolean isAlwaysNull() + { + if (block instanceof RunLengthEncodedBlock rle) { + return rle.getValue().isNull(0); + } + return false; + } + + private boolean isNotNull(int position) + { + return !mayHaveNull || !block.isNull(position); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index a84db0a52e2a..a1364995ffb8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -30,10 +30,12 @@ import io.trino.spi.type.RealType; import io.trino.spi.type.TimestampType; import io.trino.sql.gen.IsolatedClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.lang.invoke.MethodHandle; +import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; @@ -44,16 +46,22 @@ public class TestAccumulatorCompiler { - @Test - public void testAccumulatorCompilerForTypeSpecificObjectParameter() + @DataProvider(name = "specializedLoops") + public static Object[][] hashEnabledValuesProvider() + { + return new Object[][] {{true}, {false}}; + } + + @Test(dataProvider = "specializedLoops") + public void testAccumulatorCompilerForTypeSpecificObjectParameter(boolean specializedLoops) { TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); - assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class); + assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class, specializedLoops); } - @Test - public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader() + @Test(dataProvider = "specializedLoops") + public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader(boolean specializedLoops) throws Exception { TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; @@ -72,10 +80,10 @@ public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLo assertThat(aggregation.getCanonicalName()).isEqualTo(LongTimestampAggregation.class.getCanonicalName()); assertThat(aggregation).isNotSameAs(LongTimestampAggregation.class); - assertGenerateAccumulator(aggregation, stateInterface); + assertGenerateAccumulator(aggregation, stateInterface, specializedLoops); } - private static void assertGenerateAccumulator(Class aggregation, Class stateInterface) + private static void assertGenerateAccumulator(Class aggregation, Class stateInterface, boolean specializedLoops) { AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(stateInterface); AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateInterface); @@ -94,7 +102,7 @@ private static void assertGenerateAccumulator(Cl FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false)); // test if we can compile aggregation - AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, implementation, functionNullability); + AccumulatorFactory accumulatorFactory = generateAccumulatorFactory(signature, implementation, functionNullability, specializedLoops); assertThat(accumulatorFactory).isNotNull(); assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, implementation, functionNullability)).isNotNull(); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java new file mode 100644 index 000000000000..86be74a8b2f8 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java @@ -0,0 +1,185 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.InOut; +import io.trino.spi.function.SqlType; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static io.trino.operator.aggregation.AggregationLoopBuilder.buildLoop; +import static io.trino.operator.aggregation.AggregationLoopBuilder.toGroupedInputFunction; +import static java.lang.invoke.MethodHandles.explicitCastArguments; +import static java.lang.invoke.MethodHandles.insertArguments; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestAggregationLoopBuilder +{ + private static final MethodHandle INPUT_FUNCTION; + private static final Object LAMBDA_A = "lambda a"; + private static final Object LAMBDA_B = 1234L; + + static { + try { + INPUT_FUNCTION = lookup().findStatic( + TestAggregationLoopBuilder.class, + "input", + methodType(void.class, InvocationList.class, Block.class, int.class, Block.class, int.class, Object.class, Object.class)); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + private MethodHandle loop; + private List keyBlocks; + private List valueBlocks; + + @BeforeClass + public void setUp() + throws ReflectiveOperationException + { + loop = buildLoop(INPUT_FUNCTION, 1, 2, 2); + + Block keyBasic = new IntArrayBlock(5, Optional.empty(), new int[] {10, 11, 12, 13, 14}); + Block keyRleValue = new IntArrayBlock(1, Optional.empty(), new int[] {33}); + Block keyDictionary = new IntArrayBlock(3, Optional.empty(), new int[] {55, 54, 53}); + + keyBlocks = ImmutableList.builder() + .add(new TestParameter(keyBasic, keyBasic, new int[] {0, 1, 2, 3, 4})) + .add(new TestParameter(RunLengthEncodedBlock.create(keyRleValue, 5), keyRleValue, new int[] {0, 0, 0, 0, 0})) + .add(new TestParameter(DictionaryBlock.create(7, keyDictionary, new int[] {9, 9, 2, 1, 0, 1, 2}).getRegion(2, 5), keyDictionary, new int[] {2, 1, 0, 1, 2})) + .build(); + + Block valueBasic = new IntArrayBlock(5, Optional.empty(), new int[] {10, 11, 12, 13, 14}); + Block valueRleValue = new IntArrayBlock(1, Optional.empty(), new int[] {44}); + Block valueDictionary = new IntArrayBlock(3, Optional.empty(), new int[] {66, 65, 64}); + + valueBlocks = ImmutableList.builder() + .add(new TestParameter(valueBasic, valueBasic, new int[] {0, 1, 2, 3, 4})) + .add(new TestParameter(RunLengthEncodedBlock.create(valueRleValue, 5), valueRleValue, new int[] {0, 0, 0, 0, 0})) + .add(new TestParameter(DictionaryBlock.create(7, valueDictionary, new int[] {9, 9, 0, 1, 2, 1, 0}).getRegion(2, 5), valueDictionary, new int[] {0, 1, 2, 1, 0})) + .build(); + } + + @Test + public void testGroupedAdapter() + throws Throwable + { + MethodHandle maxByInput = lookup().findStatic( + MaxByAggregationFunction.class, + "input", + methodType(void.class, MethodHandle.class, InOut.class, InOut.class, Block.class, int.class, Block.class, int.class)); + maxByInput = insertArguments(maxByInput, 0, (MethodHandle) null); + maxByInput = explicitCastArguments(maxByInput, methodType(void.class, AccumulatorState.class, AccumulatorState.class, Block.class, int.class, Block.class, int.class)); + MethodHandle methodHandle = toGroupedInputFunction(maxByInput, 2); + assertThat(methodHandle.type()).isEqualTo(methodType(void.class, AccumulatorState.class, AccumulatorState.class, Block.class, int.class, Block.class, int.class, Block.class, int.class)); + } + + @Test + public void testSelectAll() + throws Throwable + { + AggregationMask mask = AggregationMask.createSelectAll(5); + for (TestParameter keyBlock : keyBlocks) { + for (TestParameter valueBlock : valueBlocks) { + InvocationList invocationList = new InvocationList(); + loop.invokeExact(mask, invocationList, keyBlock.inputBlock(), valueBlock.inputBlock(), LAMBDA_A, LAMBDA_B); + assertThat(invocationList.getInvocations()).isEqualTo(buildExpectedInvocation(keyBlock, valueBlock, mask).getInvocations()); + } + } + } + + @Test + public void testMasked() + throws Throwable + { + AggregationMask mask = AggregationMask.createSelectedPositions(5, new int[] {1, 2, 4}, 3); + for (TestParameter keyBlock : keyBlocks) { + for (TestParameter valueBlock : valueBlocks) { + InvocationList invocationList = new InvocationList(); + loop.invokeExact(mask, invocationList, keyBlock.inputBlock(), valueBlock.inputBlock(), LAMBDA_A, LAMBDA_B); + assertThat(invocationList.getInvocations()).isEqualTo(buildExpectedInvocation(keyBlock, valueBlock, mask).getInvocations()); + } + } + } + + private static InvocationList buildExpectedInvocation(TestParameter keyBlock, TestParameter valueBlock, AggregationMask mask) + { + InvocationList invocationList = new InvocationList(); + int[] keyPositions = keyBlock.invokedPositions(); + int[] valuePositions = valueBlock.invokedPositions(); + if (mask.isSelectAll()) { + for (int position = 0; position < keyPositions.length; position++) { + invocationList.add(keyBlock.invokedBlock(), keyPositions[position], valueBlock.invokedBlock(), valuePositions[position], LAMBDA_A, LAMBDA_B); + } + } + else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < mask.getSelectedPositionCount(); i++) { + int position = selectedPositions[i]; + invocationList.add(keyBlock.invokedBlock(), keyPositions[position], valueBlock.invokedBlock(), valuePositions[position], LAMBDA_A, LAMBDA_B); + } + } + return invocationList; + } + + @SuppressWarnings("UnusedVariable") + private record TestParameter(Block inputBlock, Block invokedBlock, int[] invokedPositions) {} + + public static void input( + @AggregationState InvocationList invocationList, + @BlockPosition @SqlType("K") Block keyBlock, + @BlockIndex int keyPosition, + @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockIndex int valuePosition, + Object lambdaA, + Object lambdaB) + { + invocationList.add(keyBlock, keyPosition, valueBlock, valuePosition, lambdaA, lambdaB); + } + + public static class InvocationList + { + private final List invocations = new ArrayList<>(); + + public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition, Object lambdaA, Object lambdaB) + { + invocations.add(new Invocation(keyBlock, keyPosition, valueBlock, valuePosition, lambdaA, lambdaB)); + } + + public List getInvocations() + { + return ImmutableList.copyOf(invocations); + } + + public record Invocation(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition, Object lambdaA, Object lambdaB) {} + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMask.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMask.java new file mode 100644 index 000000000000..c83b34b576f4 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMask.java @@ -0,0 +1,189 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestAggregationMask +{ + @Test + public void testUnsetNulls() + { + AggregationMask aggregationMask = AggregationMask.createSelectAll(0); + assertAggregationMaskAll(aggregationMask, 0); + + for (int positionCount = 7; positionCount < 10; positionCount++) { + aggregationMask.reset(positionCount); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.unselectNullPositions(new IntArrayBlock(positionCount, Optional.empty(), new int[positionCount])); + assertAggregationMaskAll(aggregationMask, positionCount); + + boolean[] nullFlags = new boolean[positionCount]; + aggregationMask.unselectNullPositions(new IntArrayBlock(positionCount, Optional.of(nullFlags), new int[positionCount])); + assertAggregationMaskAll(aggregationMask, positionCount); + + Arrays.fill(nullFlags, true); + nullFlags[1] = false; + nullFlags[3] = false; + nullFlags[5] = false; + aggregationMask.unselectNullPositions(new IntArrayBlock(positionCount, Optional.of(nullFlags), new int[positionCount])); + assertAggregationMaskPositions(aggregationMask, positionCount, 1, 3, 5); + + nullFlags[3] = true; + aggregationMask.unselectNullPositions(new IntArrayBlock(positionCount, Optional.of(nullFlags), new int[positionCount])); + assertAggregationMaskPositions(aggregationMask, positionCount, 1, 5); + + nullFlags[1] = true; + nullFlags[5] = true; + aggregationMask.unselectNullPositions(new IntArrayBlock(positionCount, Optional.of(nullFlags), new int[positionCount])); + assertAggregationMaskPositions(aggregationMask, positionCount); + + aggregationMask.reset(positionCount); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.unselectNullPositions(RunLengthEncodedBlock.create(new IntArrayBlock(1, Optional.empty(), new int[1]), positionCount)); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.unselectNullPositions(RunLengthEncodedBlock.create(new IntArrayBlock(1, Optional.of(new boolean[] {false}), new int[1]), positionCount)); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.unselectNullPositions(RunLengthEncodedBlock.create(new IntArrayBlock(1, Optional.of(new boolean[] {true}), new int[1]), positionCount)); + assertAggregationMaskPositions(aggregationMask, positionCount); + } + } + + @Test + public void testApplyMask() + { + AggregationMask aggregationMask = AggregationMask.createSelectAll(0); + assertAggregationMaskAll(aggregationMask, 0); + + for (int positionCount = 7; positionCount < 10; positionCount++) { + aggregationMask.reset(positionCount); + assertAggregationMaskAll(aggregationMask, positionCount); + + byte[] mask = new byte[positionCount]; + Arrays.fill(mask, (byte) 1); + + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.empty(), mask)); + assertAggregationMaskAll(aggregationMask, positionCount); + + Arrays.fill(mask, (byte) 0); + mask[1] = 1; + mask[3] = 1; + mask[5] = 1; + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.empty(), mask)); + assertAggregationMaskPositions(aggregationMask, positionCount, 1, 3, 5); + + mask[3] = 0; + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.empty(), mask)); + assertAggregationMaskPositions(aggregationMask, positionCount, 1, 5); + + mask[1] = 0; + mask[5] = 0; + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.empty(), mask)); + assertAggregationMaskPositions(aggregationMask, positionCount); + + aggregationMask.reset(positionCount); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.applyMaskBlock(RunLengthEncodedBlock.create(new ByteArrayBlock(1, Optional.empty(), new byte[] {1}), positionCount)); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.applyMaskBlock(RunLengthEncodedBlock.create(new ByteArrayBlock(1, Optional.empty(), new byte[] {0}), positionCount)); + assertAggregationMaskPositions(aggregationMask, positionCount); + } + } + + @Test + public void testApplyMaskNulls() + { + AggregationMask aggregationMask = AggregationMask.createSelectAll(0); + assertAggregationMaskAll(aggregationMask, 0); + + for (int positionCount = 7; positionCount < 10; positionCount++) { + aggregationMask.reset(positionCount); + assertAggregationMaskAll(aggregationMask, positionCount); + + byte[] mask = new byte[positionCount]; + Arrays.fill(mask, (byte) 1); + + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.empty(), mask)); + assertAggregationMaskAll(aggregationMask, positionCount); + + boolean[] nullFlags = new boolean[positionCount]; + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.of(nullFlags), mask)); + assertAggregationMaskAll(aggregationMask, positionCount); + + Arrays.fill(nullFlags, true); + nullFlags[1] = false; + nullFlags[3] = false; + nullFlags[5] = false; + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.of(nullFlags), mask)); + assertAggregationMaskPositions(aggregationMask, positionCount, 1, 3, 5); + + nullFlags[3] = true; + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.of(nullFlags), mask)); + assertAggregationMaskPositions(aggregationMask, positionCount, 1, 5); + + nullFlags[1] = true; + nullFlags[5] = true; + aggregationMask.applyMaskBlock(new ByteArrayBlock(positionCount, Optional.of(nullFlags), mask)); + assertAggregationMaskPositions(aggregationMask, positionCount); + + aggregationMask.reset(positionCount); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.applyMaskBlock(RunLengthEncodedBlock.create(new ByteArrayBlock(1, Optional.empty(), new byte[] {1}), positionCount)); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.applyMaskBlock(RunLengthEncodedBlock.create(new ByteArrayBlock(1, Optional.of(new boolean[] {false}), new byte[] {1}), positionCount)); + assertAggregationMaskAll(aggregationMask, positionCount); + + aggregationMask.applyMaskBlock(RunLengthEncodedBlock.create(new ByteArrayBlock(1, Optional.of(new boolean[] {true}), new byte[] {1}), positionCount)); + assertAggregationMaskPositions(aggregationMask, positionCount); + } + } + + private static void assertAggregationMaskAll(AggregationMask aggregationMask, int expectedPositionCount) + { + assertThat(aggregationMask.isSelectAll()).isTrue(); + assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositionCount == 0); + assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount); + assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositionCount); + assertThatThrownBy(aggregationMask::getSelectedPositions).isInstanceOf(IllegalStateException.class); + } + + private static void assertAggregationMaskPositions(AggregationMask aggregationMask, int expectedPositionCount, int... expectedPositions) + { + assertThat(aggregationMask.isSelectAll()).isFalse(); + assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositions.length == 0); + assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount); + assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositions.length); + // AssertJ is buggy and does not allow starts with to contain an empty array + if (expectedPositions.length > 0) { + assertThat(aggregationMask.getSelectedPositions()).startsWith(expectedPositions); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java new file mode 100644 index 000000000000..322028b2075f --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java @@ -0,0 +1,242 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ShortArrayBlock; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Optional; +import java.util.function.Supplier; + +import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestAggregationMaskCompiler +{ + @DataProvider + public Object[][] maskBuilderSuppliers() + { + Supplier interpretedMaskBuilderSupplier = () -> new InterpretedAggregationMaskBuilder(1); + Supplier compiledMaskBuilderSupplier = () -> { + try { + return generateAggregationMaskBuilder(1).newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + }; + return new Object[][] {{compiledMaskBuilderSupplier}, {interpretedMaskBuilderSupplier}}; + } + + @Test(dataProvider = "maskBuilderSuppliers") + public void testSupplier(Supplier maskBuilderSupplier) + { + // each builder produced from a supplier could be completely independent + assertThat(maskBuilderSupplier.get()).isNotSameAs(maskBuilderSupplier.get()); + + Page page = buildSingleColumnPage(5); + assertThat(maskBuilderSupplier.get().buildAggregationMask(page, Optional.empty())) + .isNotSameAs(maskBuilderSupplier.get().buildAggregationMask(page, Optional.empty())); + + boolean[] nullFlags = new boolean[5]; + nullFlags[1] = true; + nullFlags[3] = true; + Page pageWithNulls = buildSingleColumnPage(nullFlags); + assertThat(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty())) + .isNotSameAs(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty())); + assertThat(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions()) + .isNotSameAs(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions()); + assertThat(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions()) + .isEqualTo(maskBuilderSupplier.get().buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions()); + + // a single mask builder is allowed to share arrays across builds + AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get(); + assertThat(maskBuilder.buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions()) + .isSameAs(maskBuilder.buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions()); + } + + @Test(dataProvider = "maskBuilderSuppliers") + public void testUnsetNulls(Supplier maskBuilderSupplier) + { + AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get(); + AggregationMask aggregationMask = maskBuilder.buildAggregationMask(buildSingleColumnPage(0), Optional.empty()); + assertAggregationMaskAll(aggregationMask, 0); + + for (int positionCount = 7; positionCount < 10; positionCount++) { + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(true)), Optional.empty()), positionCount); + + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.empty()), positionCount); + + boolean[] nullFlags = new boolean[positionCount]; + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount); + + Arrays.fill(nullFlags, true); + nullFlags[1] = false; + nullFlags[3] = false; + nullFlags[5] = false; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount, 1, 3, 5); + + nullFlags[3] = true; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount, 1, 5); + + nullFlags[2] = false; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(nullFlags), Optional.empty()), positionCount, 1, 2, 5); + + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.empty()), Optional.empty()), positionCount); + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(false)), Optional.empty()), positionCount); + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(true)), Optional.empty()), positionCount); + } + } + + @Test(dataProvider = "maskBuilderSuppliers") + public void testApplyMask(Supplier maskBuilderSupplier) + { + AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get(); + + for (int positionCount = 7; positionCount < 10; positionCount++) { + byte[] mask = new byte[positionCount]; + Arrays.fill(mask, (byte) 1); + + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount); + + Arrays.fill(mask, (byte) 0); + mask[1] = 1; + mask[3] = 1; + mask[5] = 1; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 3, 5); + + mask[3] = 0; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 5); + + mask[2] = 1; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 2, 5); + + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 1))), positionCount); + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 0))), positionCount); + } + } + + @Test(dataProvider = "maskBuilderSuppliers") + public void testApplyMaskNulls(Supplier maskBuilderSupplier) + { + AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get(); + + for (int positionCount = 7; positionCount < 10; positionCount++) { + byte[] mask = new byte[positionCount]; + Arrays.fill(mask, (byte) 1); + + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount); + + boolean[] nullFlags = new boolean[positionCount]; + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount); + + Arrays.fill(nullFlags, true); + nullFlags[1] = false; + nullFlags[3] = false; + nullFlags[5] = false; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount, 1, 3, 5); + + nullFlags[3] = true; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount, 1, 5); + + nullFlags[1] = true; + nullFlags[5] = true; + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount); + + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNullsRle(positionCount, false))), positionCount); + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNullsRle(positionCount, true))), positionCount); + } + } + + private static Block createMaskBlock(int positionCount, byte[] mask) + { + return new ByteArrayBlock(positionCount, Optional.empty(), mask); + } + + private static Block createMaskBlockRle(int positionCount, byte mask) + { + return RunLengthEncodedBlock.create(createMaskBlock(1, new byte[] {mask}), positionCount); + } + + private static Block createMaskBlockNulls(boolean[] nulls) + { + int positionCount = nulls.length; + byte[] mask = new byte[positionCount]; + Arrays.fill(mask, (byte) 1); + return new ByteArrayBlock(positionCount, Optional.of(nulls), mask); + } + + private static Block createMaskBlockNullsRle(int positionCount, boolean nullValue) + { + return RunLengthEncodedBlock.create(createMaskBlockNulls(new boolean[] {nullValue}), positionCount); + } + + private static Page buildSingleColumnPage(int positionCount) + { + boolean[] ignoredColumnNulls = new boolean[positionCount]; + Arrays.fill(ignoredColumnNulls, true); + return new Page( + new ShortArrayBlock(positionCount, Optional.of(ignoredColumnNulls), new short[positionCount]), + new IntArrayBlock(positionCount, Optional.empty(), new int[positionCount])); + } + + private static Page buildSingleColumnPage(boolean[] nulls) + { + int positionCount = nulls.length; + boolean[] ignoredColumnNulls = new boolean[positionCount]; + Arrays.fill(ignoredColumnNulls, true); + return new Page( + new ShortArrayBlock(positionCount, Optional.of(ignoredColumnNulls), new short[positionCount]), + new IntArrayBlock(positionCount, Optional.of(nulls), new int[positionCount])); + } + + private static Page buildSingleColumnPageRle(int positionCount, Optional nullValue) + { + Optional nulls = nullValue.map(value -> new boolean[] {value}); + boolean[] ignoredColumnNulls = new boolean[positionCount]; + Arrays.fill(ignoredColumnNulls, true); + return new Page( + new ShortArrayBlock(positionCount, Optional.of(ignoredColumnNulls), new short[positionCount]), + RunLengthEncodedBlock.create(new IntArrayBlock(1, nulls, new int[positionCount]), positionCount)); + } + + private static void assertAggregationMaskAll(AggregationMask aggregationMask, int expectedPositionCount) + { + assertThat(aggregationMask.isSelectAll()).isTrue(); + assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositionCount == 0); + assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount); + assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositionCount); + assertThatThrownBy(aggregationMask::getSelectedPositions).isInstanceOf(IllegalStateException.class); + } + + private static void assertAggregationMaskPositions(AggregationMask aggregationMask, int expectedPositionCount, int... expectedPositions) + { + assertThat(aggregationMask.isSelectAll()).isFalse(); + assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositions.length == 0); + assertThat(aggregationMask.getPositionCount()).isEqualTo(expectedPositionCount); + assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositions.length); + // AssertJ is buggy and does not allow starts with to contain an empty array + if (expectedPositions.length > 0) { + assertThat(aggregationMask.getSelectedPositions()).startsWith(expectedPositions); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java index 8947bb94ed18..377282743cec 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java @@ -52,7 +52,7 @@ public TestingAggregationFunction(BoundSignature signature, FunctionNullability .collect(toImmutableList()); intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); this.finalType = signature.getReturnType(); - this.factory = generateAccumulatorFactory(signature, aggregationImplementation, functionNullability); + this.factory = generateAccumulatorFactory(signature, aggregationImplementation, functionNullability, true); distinctFactory = new DistinctAccumulatorFactory( factory, parameterTypes, diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java index 1aabfd0ed015..6224f0e7f6fe 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java @@ -56,11 +56,11 @@ public void testInputEmptyState() ListaggAggregationFunction.input( state, value, + 0, separator, false, overflowFiller, - true, - 0); + true); assertFalse(state.isEmpty()); assertEquals(state.getSeparator(), separator); @@ -88,11 +88,11 @@ public void testInputOverflowOverflowFillerTooLong() assertThatThrownBy(() -> ListaggAggregationFunction.input( state, createStringsBlock("value1"), + 0, utf8Slice(","), false, utf8Slice(overflowFillerTooLong), - false, - 0)) + false)) .isInstanceOf(TrinoException.class) .matches(throwable -> ((TrinoException) throwable).getErrorCode() == INVALID_FUNCTION_ARGUMENT.toErrorCode()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java index d2fa9f6e256b..a7c65203edff 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java @@ -28,7 +28,8 @@ public class TestCompilerConfig public void testDefaults() { assertRecordedDefaults(recordDefaults(CompilerConfig.class) - .setExpressionCacheSize(10_000)); + .setExpressionCacheSize(10_000) + .setSpecializeAggregationLoops(true)); } @Test @@ -36,10 +37,12 @@ public void testExplicitPropertyMappings() { Map properties = ImmutableMap.builder() .put("compiler.expression-cache-size", "52") + .put("compiler.specialized-aggregation-loops", "false") .buildOrThrow(); CompilerConfig expected = new CompilerConfig() - .setExpressionCacheSize(52); + .setExpressionCacheSize(52) + .setSpecializeAggregationLoops(false); assertFullMapping(properties, expected); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java index 061715433bc1..b9f913f138d8 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java @@ -130,12 +130,12 @@ private DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[ this.isSequentialIds = isSequentialIds; } - int[] getRawIds() + public int[] getRawIds() { return ids; } - int getRawIdsOffset() + public int getRawIdsOffset() { return idsOffset; } diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java index 62ca20d61bae..9f48902a4bf5 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java @@ -39,7 +39,7 @@ public BenchmarkAggregationFunction(ResolvedFunction resolvedFunction, Aggregati BoundSignature signature = resolvedFunction.getSignature(); intermediateType = getOnlyElement(aggregationImplementation.getAccumulatorStateDescriptors()).getSerializer().getSerializedType(); finalType = signature.getReturnType(); - accumulatorFactory = generateAccumulatorFactory(signature, aggregationImplementation, resolvedFunction.getFunctionNullability()); + accumulatorFactory = generateAccumulatorFactory(signature, aggregationImplementation, resolvedFunction.getFunctionNullability(), true); } public AggregatorFactory bind(List inputChannels)