From 81b746d8c2cada94efabca9a993cc79a179407a7 Mon Sep 17 00:00:00 2001 From: radek-starburst <94364205+radek-starburst@users.noreply.github.com> Date: Mon, 7 Nov 2022 10:22:00 +0100 Subject: [PATCH] Skip null checks for positions from non-null blocks in aggregation Using io.trino.spi.block.Block.mayHaveNull, it can be detected that there are no null positions in the block. Based on that, we skip checking the nullability of every position for such blocks in the aggregation `input` method for performance reasons. --- .../aggregation/AccumulatorCompiler.java | 71 ++++++++++++++----- 1 file changed, 52 insertions(+), 19 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 24edcb571d79..a86dcc362bac 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -509,6 +509,46 @@ private static BytecodeBlock generateInputForLoop( .putVariable(rowsVariable) .initializeVariable(positionVariable); + /* + For performance reasons, it differentiates two cases: (1) when a block may have null positions and (2) when there are no null positions. + The expected skeleton of generated code is: + if false or block.mayHaveNull() or ... + for position in 0..rows + if CompilerOperations.testMask(masksBlock, position) and !block0.isNull(position) and ... 
+ this.state_0.input(this.state_0, block0, ..., position) + else + for position in 0..rows + if CompilerOperations.testMask(masksBlock, position) + this.state_0.input(this.state_0, block0, ..., position); + + */ + ForLoop nullCheckLoop = generateInputLoopBody(true, scope, stateField, positionVariable, parameterVariables, lambdaProviderFields, inputFunction, callSiteBinder, grouped, argumentNullable, masksBlock, rowsVariable); + ForLoop noNullCheckLoop = generateInputLoopBody(false, scope, stateField, positionVariable, parameterVariables, lambdaProviderFields, inputFunction, callSiteBinder, grouped, argumentNullable, masksBlock, rowsVariable); + + // prepare mayHaveNull condition + BytecodeExpression mayHaveNullCondition = BytecodeExpressions.constantFalse(); + for (int parameterIndex = 0; parameterIndex < parameterVariables.size(); parameterIndex++) { + if (!argumentNullable.get(parameterIndex)) { + mayHaveNullCondition = BytecodeExpressions.or(mayHaveNullCondition, parameterVariables.get(parameterIndex).invoke("mayHaveNull", boolean.class)); + } + } + + IfStatement mayHaveNullIf = new IfStatement("if(%s)", mayHaveNullCondition).condition(mayHaveNullCondition) + .ifFalse(noNullCheckLoop) + .ifTrue(nullCheckLoop); + + block.append(new IfStatement("if(!maskGuaranteedToFilterAllRows(%s, %s))", rowsVariable.getName(), masksBlock.getName()) + .condition(new BytecodeBlock() + .getVariable(rowsVariable) + .getVariable(masksBlock) + .invokeStatic(AggregationUtils.class, "maskGuaranteedToFilterAllRows", boolean.class, int.class, Block.class)) + .ifFalse(mayHaveNullIf)); + + return block; + } + + private static ForLoop generateInputLoopBody(boolean isNullCheck, Scope scope, List stateField, Variable positionVariable, List parameterVariables, List lambdaProviderFields, MethodHandle inputFunction, CallSiteBinder callSiteBinder, boolean grouped, List argumentNullable, Variable masksBlock, Variable rowsVariable) + { BytecodeNode loopBody = generateInvokeInputFunction( scope, 
stateField, @@ -520,15 +560,17 @@ private static BytecodeBlock generateInputForLoop( grouped); // Wrap with null checks - for (int parameterIndex = 0; parameterIndex < parameterVariables.size(); parameterIndex++) { - if (!argumentNullable.get(parameterIndex)) { - Variable variableDefinition = parameterVariables.get(parameterIndex); - loopBody = new IfStatement("if(!%s.isNull(position))", variableDefinition.getName()) - .condition(new BytecodeBlock() - .getVariable(variableDefinition) - .getVariable(positionVariable) - .invokeInterface(Block.class, "isNull", boolean.class, int.class)) - .ifFalse(loopBody); + if (isNullCheck) { + for (int parameterIndex = 0; parameterIndex < parameterVariables.size(); parameterIndex++) { + if (!argumentNullable.get(parameterIndex)) { + Variable variableDefinition = parameterVariables.get(parameterIndex); + loopBody = new IfStatement("if(!%s.isNull(position))", variableDefinition.getName()) + .condition(new BytecodeBlock() + .getVariable(variableDefinition) + .getVariable(positionVariable) + .invokeInterface(Block.class, "isNull", boolean.class, int.class)) + .ifFalse(loopBody); + } } } @@ -539,7 +581,7 @@ private static BytecodeBlock generateInputForLoop( .invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class)) .ifTrue(loopBody); - ForLoop forLoop = new ForLoop() + return new ForLoop() .initialize(new BytecodeBlock().putVariable(positionVariable, 0)) .condition(new BytecodeBlock() .getVariable(positionVariable) @@ -547,15 +589,6 @@ private static BytecodeBlock generateInputForLoop( .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)) .update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1)) .body(loopBody); - - block.append(new IfStatement("if(!maskGuaranteedToFilterAllRows(%s, %s))", rowsVariable.getName(), masksBlock.getName()) - .condition(new BytecodeBlock() - .getVariable(rowsVariable) - .getVariable(masksBlock) - 
.invokeStatic(AggregationUtils.class, "maskGuaranteedToFilterAllRows", boolean.class, int.class, Block.class)) - .ifFalse(forLoop)); - - return block; } private static BytecodeBlock generateInvokeInputFunction(