From a1b44a5a45f512a6e516391c39699c53a49f9f29 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Thu, 2 Dec 2021 14:39:17 +0530 Subject: [PATCH] Load lazy blocks before using them in page filter/projection Helps to avoid calls to LazyData#getTopLevelBlock in generated page filter and projection methods. PageFieldsToInputParametersRewriter now also records which channels are evaluated unconditionally so that LazyBlock can be loaded for those channels before expression evaluation. --- .../trino/operator/project/InputChannels.java | 11 +- .../PageFieldsToInputParametersRewriter.java | 77 +++++++-- ...stPageFieldsToInputParametersRewriter.java | 155 ++++++++++++++++++ .../src/main/java/io/trino/spi/Page.java | 18 ++ .../src/test/java/io/trino/spi/TestPage.java | 37 +++++ 5 files changed, 279 insertions(+), 19 deletions(-) create mode 100644 core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java diff --git a/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java b/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java index 511398b7732d..4c4ffcd979e3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java +++ b/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java @@ -13,6 +13,7 @@ */ package io.trino.operator.project; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import io.trino.spi.Page; @@ -25,15 +26,23 @@ public class InputChannels { private final int[] inputChannels; + private final int[] eagerlyLoadedChannels; public InputChannels(int... 
inputChannels) { this.inputChannels = inputChannels.clone(); + this.eagerlyLoadedChannels = new int[0]; } public InputChannels(List inputChannels) + { + this(inputChannels, ImmutableList.of()); + } + + public InputChannels(List inputChannels, List eagerlyLoadedChannels) { this.inputChannels = inputChannels.stream().mapToInt(Integer::intValue).toArray(); + this.eagerlyLoadedChannels = eagerlyLoadedChannels.stream().mapToInt(Integer::intValue).toArray(); } public int size() @@ -48,7 +57,7 @@ public List getInputChannels() public Page getInputChannels(Page page) { - return page.getColumns(inputChannels); + return page.getLoadedPage(inputChannels, eagerlyLoadedChannels); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java b/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java index 4a743a13350c..3a74e4e4230e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java +++ b/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java @@ -25,8 +25,11 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.sql.relational.Expressions.field; @@ -42,16 +45,17 @@ private PageFieldsToInputParametersRewriter() {} public static Result rewritePageFieldsToInputParameters(RowExpression expression) { Visitor visitor = new Visitor(); - RowExpression rewrittenProjection = expression.accept(visitor, null); - InputChannels inputChannels = new InputChannels(visitor.getInputChannels()); + RowExpression rewrittenProjection = expression.accept(visitor, true); + InputChannels inputChannels = new InputChannels(visitor.getInputChannels(), visitor.getEagerlyLoadedChannels()); 
return new Result(rewrittenProjection, inputChannels); } private static class Visitor - implements RowExpressionVisitor + implements RowExpressionVisitor { private final Map fieldToParameter = new HashMap<>(); private final List inputChannels = new ArrayList<>(); + private final Set eagerlyLoadedChannels = new HashSet<>(); private int nextParameter; public List getInputChannels() @@ -59,9 +63,17 @@ public List getInputChannels() return ImmutableList.copyOf(inputChannels); } + public List getEagerlyLoadedChannels() + { + return ImmutableList.copyOf(eagerlyLoadedChannels); + } + @Override - public RowExpression visitInputReference(InputReferenceExpression reference, Void context) + public RowExpression visitInputReference(InputReferenceExpression reference, Boolean unconditionallyEvaluated) { + if (unconditionallyEvaluated) { + eagerlyLoadedChannels.add(reference.getField()); + } int parameter = getParameterForField(reference); return field(parameter, reference.getType()); } @@ -75,44 +87,73 @@ private Integer getParameterForField(InputReferenceExpression reference) } @Override - public RowExpression visitCall(CallExpression call, Void context) + public RowExpression visitCall(CallExpression call, Boolean unconditionallyEvaluated) { + boolean containsLambdaExpression = call.getArguments().stream().anyMatch(LambdaDefinitionExpression.class::isInstance); return new CallExpression( call.getResolvedFunction(), call.getArguments().stream() - .map(expression -> expression.accept(this, context)) + // Lambda expressions may use only some of their input references, e.g. transform(elements, x -> 1) + // TODO: Currently we fall back to assuming that all the arguments are conditionally evaluated when + // a lambda expression is encountered for the sake of simplicity. 
+ .map(expression -> expression.accept(this, unconditionallyEvaluated && !containsLambdaExpression)) .collect(toImmutableList())); } @Override - public RowExpression visitSpecialForm(SpecialForm specialForm, Void context) + public RowExpression visitSpecialForm(SpecialForm specialForm, Boolean unconditionallyEvaluated) { - return new SpecialForm( - specialForm.getForm(), - specialForm.getType(), - specialForm.getArguments().stream() - .map(expression -> expression.accept(this, context)) - .collect(toImmutableList()), - specialForm.getFunctionDependencies()); + switch (specialForm.getForm()) { + case IF: + case SWITCH: + case BETWEEN: + case AND: + case OR: + case COALESCE: + List arguments = specialForm.getArguments(); + return new SpecialForm( + specialForm.getForm(), + specialForm.getType(), + IntStream.range(0, arguments.size()).boxed() + // All the arguments after the first one are assumed to be conditionally evaluated + .map(index -> arguments.get(index).accept(this, index == 0 && unconditionallyEvaluated)) + .collect(toImmutableList()), + specialForm.getFunctionDependencies()); + case BIND: + case IN: + case WHEN: + case IS_NULL: + case NULL_IF: + case DEREFERENCE: + case ROW_CONSTRUCTOR: + return new SpecialForm( + specialForm.getForm(), + specialForm.getType(), + specialForm.getArguments().stream() + .map(expression -> expression.accept(this, unconditionallyEvaluated)) + .collect(toImmutableList()), + specialForm.getFunctionDependencies()); + } + throw new IllegalArgumentException("Unsupported special form " + specialForm.getForm()); } @Override - public RowExpression visitConstant(ConstantExpression literal, Void context) + public RowExpression visitConstant(ConstantExpression literal, Boolean unconditionallyEvaluated) { return literal; } @Override - public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context) + public RowExpression visitLambda(LambdaDefinitionExpression lambda, Boolean unconditionallyEvaluated) { return new 
LambdaDefinitionExpression( lambda.getArgumentTypes(), lambda.getArguments(), - lambda.getBody().accept(this, context)); + lambda.getBody().accept(this, unconditionallyEvaluated)); } @Override - public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context) + public RowExpression visitVariableReference(VariableReferenceExpression reference, Boolean unconditionallyEvaluated) { return reference; } diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java new file mode 100644 index 000000000000..dea42dc8fad7 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java @@ -0,0 +1,155 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator.project; + +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.metadata.Metadata; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.LazyBlock; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.Type; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TypeAnalyzer; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.relational.RowExpression; +import io.trino.sql.relational.SqlToRowExpressionTranslator; +import io.trino.sql.tree.Expression; +import io.trino.testing.TestingSession; +import io.trino.transaction.TransactionId; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.block.BlockAssertions.createLongSequenceBlock; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; +import static io.trino.operator.project.PageFieldsToInputParametersRewriter.Result; +import static io.trino.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.ExpressionTestUtils.createExpression; +import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; +import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestPageFieldsToInputParametersRewriter +{ + private static final Metadata METADATA = createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT); + private static final Session TEST_SESSION = TestingSession.testSessionBuilder() + .setTransactionId(TransactionId.create()) + .build(); + + @Test + 
public void testEagerLoading() + { + RowExpressionBuilder builder = RowExpressionBuilder.create() + .addSymbol("bigint0", BIGINT) + .addSymbol("bigint1", BIGINT); + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 + 5"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("CAST((bigint0 * 10) AS INT)"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE((bigint0 % 2), bigint0)"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 IN (1, 2, 3)"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 > 0"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 + 1 = 0"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 BETWEEN 1 AND 10"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint0 ELSE null END"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("CASE bigint0 WHEN 1 THEN 1 ELSE -bigint0 END"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("IF(bigint0 >= 150000, 0, 1)"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("IF(bigint0 >= 150000, bigint0, 0)"), 1); + verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE(0, bigint0) + bigint0"), 1); + + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 + (2 * bigint1)"), 2); + verifyEagerlyLoadedColumns(builder.buildExpression("NULLIF(bigint0, bigint1)"), 2); + verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE(CEIL(bigint0 / bigint1), 0)"), 2); + verifyEagerlyLoadedColumns(builder.buildExpression("CASE WHEN (bigint0 > bigint1) THEN 1 ELSE 0 END"), 2); + verifyEagerlyLoadedColumns( + builder.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint1 ELSE 0 END"), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE(ROUND(bigint0), bigint1)"), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 > 0 AND bigint1 > 0"), 2, ImmutableSet.of(0)); + 
verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 > 0 OR bigint1 > 0"), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 BETWEEN 0 AND bigint1"), 2, ImmutableSet.of(0)); + verifyEagerlyLoadedColumns(builder.buildExpression("IF(bigint1 >= 150000, 0, bigint0)"), 2, ImmutableSet.of(0)); + + builder = RowExpressionBuilder.create() + .addSymbol("array_bigint0", new ArrayType(BIGINT)) + .addSymbol("array_bigint1", new ArrayType(BIGINT)); + verifyEagerlyLoadedColumns(builder.buildExpression("TRANSFORM(array_bigint0, x -> 1)"), 1, ImmutableSet.of()); + verifyEagerlyLoadedColumns(builder.buildExpression("TRANSFORM(array_bigint0, x -> 2 * x)"), 1, ImmutableSet.of()); + verifyEagerlyLoadedColumns(builder.buildExpression("ZIP_WITH(array_bigint0, array_bigint1, (x, y) -> 2 * x)"), 2, ImmutableSet.of()); + } + + private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount) + { + verifyEagerlyLoadedColumns(rowExpression, columnCount, IntStream.range(0, columnCount).boxed().collect(toImmutableSet())); + } + + private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount, Set eagerlyLoadedChannels) + { + Result result = rewritePageFieldsToInputParameters(rowExpression); + Block[] blocks = new Block[columnCount]; + for (int channel = 0; channel < columnCount; channel++) { + blocks[channel] = lazyWrapper(createLongSequenceBlock(0, 100)); + } + Page page = result.getInputChannels().getInputChannels(new Page(blocks)); + for (int channel = 0; channel < columnCount; channel++) { + assertThat(page.getBlock(channel).isLoaded()).isEqualTo(eagerlyLoadedChannels.contains(channel)); + } + } + + private static LazyBlock lazyWrapper(Block block) + { + return new LazyBlock(block.getPositionCount(), block::getLoadedBlock); + } + + private static class RowExpressionBuilder + { + private final Map symbolTypes = new HashMap<>(); + private final Map sourceLayout = new HashMap<>(); + private 
final List types = new LinkedList<>(); + + private static RowExpressionBuilder create() + { + return new RowExpressionBuilder(); + } + + private RowExpressionBuilder addSymbol(String name, Type type) + { + Symbol symbol = new Symbol(name); + symbolTypes.put(symbol, type); + sourceLayout.put(symbol, types.size()); + types.add(type); + return this; + } + + private RowExpression buildExpression(String value) + { + Expression expression = createExpression(value, PLANNER_CONTEXT, TypeProvider.copyOf(symbolTypes)); + + return SqlToRowExpressionTranslator.translate( + expression, + TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), + sourceLayout, + METADATA, + TEST_SESSION, + true); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/Page.java b/core/trino-spi/src/main/java/io/trino/spi/Page.java index f612bafd5f53..1a8873d4b097 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/Page.java +++ b/core/trino-spi/src/main/java/io/trino/spi/Page.java @@ -315,6 +315,24 @@ public Page getLoadedPage(int... 
columns) return wrapBlocksWithoutCopy(positionCount, blocks); } + public Page getLoadedPage(int[] columns, int[] eagerlyLoadedColumns) + { + requireNonNull(columns, "columns is null"); + + for (int column : eagerlyLoadedColumns) { + this.blocks[column] = this.blocks[column].getLoadedBlock(); + } + if (retainedSizeInBytes != -1 && eagerlyLoadedColumns.length > 0) { + updateRetainedSize(); + } + Block[] blocks = new Block[columns.length]; + for (int i = 0; i < columns.length; i++) { + blocks[i] = this.blocks[columns[i]]; + } + + return wrapBlocksWithoutCopy(positionCount, blocks); + } + @Override public String toString() { diff --git a/core/trino-spi/src/test/java/io/trino/spi/TestPage.java b/core/trino-spi/src/test/java/io/trino/spi/TestPage.java index 840777bef28a..737ff0fa72b9 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/TestPage.java +++ b/core/trino-spi/src/test/java/io/trino/spi/TestPage.java @@ -19,10 +19,12 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.DictionaryId; +import io.trino.spi.block.LazyBlock; import org.testng.annotations.Test; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verifyNotNull; +import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.spi.block.DictionaryId.randomDictionaryId; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; @@ -134,6 +136,41 @@ public void testGetPositions() } } + @Test + public void testGetLoadedPage() + { + int entries = 10; + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, entries); + for (int i = 0; i < entries; i++) { + BIGINT.writeLong(blockBuilder, i); + } + Block block = blockBuilder.build(); + + LazyBlock lazyBlock = lazyWrapper(block); + Page page = new Page(lazyBlock); + long lazyPageRetainedSize = Page.INSTANCE_SIZE + sizeOf(new Block[] {block}) + lazyBlock.getRetainedSizeInBytes(); + 
assertEquals(page.getRetainedSizeInBytes(), lazyPageRetainedSize); + Page loadedPage = page.getLoadedPage(); + // Retained size of page remains the same + assertEquals(page.getRetainedSizeInBytes(), lazyPageRetainedSize); + long loadedPageRetainedSize = Page.INSTANCE_SIZE + sizeOf(new Block[] {block}) + block.getRetainedSizeInBytes(); + // Retained size of loaded page depends on the loaded block + assertEquals(loadedPage.getRetainedSizeInBytes(), loadedPageRetainedSize); + + lazyBlock = lazyWrapper(block); + page = new Page(lazyBlock); + assertEquals(page.getRetainedSizeInBytes(), lazyPageRetainedSize); + loadedPage = page.getLoadedPage(new int[] {0}, new int[] {0}); + // Retained size of page is updated based on loaded block + assertEquals(page.getRetainedSizeInBytes(), loadedPageRetainedSize); + assertEquals(loadedPage.getRetainedSizeInBytes(), loadedPageRetainedSize); + } + + private static LazyBlock lazyWrapper(Block block) + { + return new LazyBlock(block.getPositionCount(), block::getLoadedBlock); + } + private static Slice[] createExpectedValues(int positionCount) { Slice[] expectedValues = new Slice[positionCount];