diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 24edcb571d79..a86dcc362bac 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -509,6 +509,46 @@ private static BytecodeBlock generateInputForLoop( .putVariable(rowsVariable) .initializeVariable(positionVariable); + /* + It differentiates two cases: (1) when a block may have null positions and (2) when there is no null positions for performance reason. + The expected skeleton of generated code is: + if false or block.mayHaveNull() or ... + for position in 0..rows + if CompilerOperations.testMask(masksBlock, position) and !block0.isNull(position) and ... + this.state_0.input(this.state_0, block0, ..., position) + else + for position in 0..rows + if CompilerOperations.testMask(masksBlock, position) + this.state_0.input(this.state_0, block0, ..., position); + + */ + ForLoop nullCheckLoop = generateInputLoopBody(true, scope, stateField, positionVariable, parameterVariables, lambdaProviderFields, inputFunction, callSiteBinder, grouped, argumentNullable, masksBlock, rowsVariable); + ForLoop noNullCheckLoop = generateInputLoopBody(false, scope, stateField, positionVariable, parameterVariables, lambdaProviderFields, inputFunction, callSiteBinder, grouped, argumentNullable, masksBlock, rowsVariable); + + // prepare mayHaveNull condition + BytecodeExpression mayHaveNullCondition = BytecodeExpressions.constantFalse(); + for (int parameterIndex = 0; parameterIndex < parameterVariables.size(); parameterIndex++) { + if (!argumentNullable.get(parameterIndex)) { + mayHaveNullCondition = BytecodeExpressions.or(mayHaveNullCondition, parameterVariables.get(parameterIndex).invoke("mayHaveNull", boolean.class)); + } + } + + IfStatement mayHaveNullIf = new IfStatement("if(%s)", mayHaveNullCondition).condition(mayHaveNullCondition) + .ifFalse(noNullCheckLoop) + .ifTrue(nullCheckLoop); + + block.append(new IfStatement("if(!maskGuaranteedToFilterAllRows(%s, %s))", rowsVariable.getName(), masksBlock.getName()) + .condition(new BytecodeBlock() + .getVariable(rowsVariable) + .getVariable(masksBlock) + .invokeStatic(AggregationUtils.class, "maskGuaranteedToFilterAllRows", boolean.class, int.class, Block.class)) + .ifFalse(mayHaveNullIf)); + + return block; + } + + private static ForLoop generateInputLoopBody(boolean isNullCheck, Scope scope, List stateField, Variable positionVariable, List parameterVariables, List lambdaProviderFields, MethodHandle inputFunction, CallSiteBinder callSiteBinder, boolean grouped, List argumentNullable, Variable masksBlock, Variable rowsVariable) + { BytecodeNode loopBody = generateInvokeInputFunction( scope, stateField, @@ -520,15 +560,17 @@ private static BytecodeBlock generateInputForLoop( grouped); // Wrap with null checks - for (int parameterIndex = 0; parameterIndex < parameterVariables.size(); parameterIndex++) { - if (!argumentNullable.get(parameterIndex)) { - Variable variableDefinition = parameterVariables.get(parameterIndex); - loopBody = new IfStatement("if(!%s.isNull(position))", variableDefinition.getName()) - .condition(new BytecodeBlock() - .getVariable(variableDefinition) - .getVariable(positionVariable) - .invokeInterface(Block.class, "isNull", boolean.class, int.class)) - .ifFalse(loopBody); + if (isNullCheck) { + for (int parameterIndex = 0; parameterIndex < parameterVariables.size(); parameterIndex++) { + if (!argumentNullable.get(parameterIndex)) { + Variable variableDefinition = parameterVariables.get(parameterIndex); + loopBody = new IfStatement("if(!%s.isNull(position))", variableDefinition.getName()) + .condition(new BytecodeBlock() + .getVariable(variableDefinition) + .getVariable(positionVariable) + .invokeInterface(Block.class, "isNull", boolean.class, int.class)) + .ifFalse(loopBody); + } } } @@ -539,7 +581,7 @@ private static BytecodeBlock generateInputForLoop( .invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class)) .ifTrue(loopBody); - ForLoop forLoop = new ForLoop() + return new ForLoop() .initialize(new BytecodeBlock().putVariable(positionVariable, 0)) .condition(new BytecodeBlock() .getVariable(positionVariable) @@ -547,15 +589,6 @@ private static BytecodeBlock generateInputForLoop( .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)) .update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1)) .body(loopBody); - - block.append(new IfStatement("if(!maskGuaranteedToFilterAllRows(%s, %s))", rowsVariable.getName(), masksBlock.getName()) - .condition(new BytecodeBlock() - .getVariable(rowsVariable) - .getVariable(masksBlock) - .invokeStatic(AggregationUtils.class, "maskGuaranteedToFilterAllRows", boolean.class, int.class, Block.class)) - .ifFalse(forLoop)); - - return block; } private static BytecodeBlock generateInvokeInputFunction(