-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Load lazy blocks before using them in page filter/projection #10322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,8 +25,11 @@ | |
|
|
||
| import java.util.ArrayList; | ||
| import java.util.HashMap; | ||
| import java.util.HashSet; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Set; | ||
| import java.util.stream.IntStream; | ||
|
|
||
| import static com.google.common.collect.ImmutableList.toImmutableList; | ||
| import static io.trino.sql.relational.Expressions.field; | ||
|
|
@@ -42,26 +45,35 @@ private PageFieldsToInputParametersRewriter() {} | |
| /** Rewrites {@code expression} with a fresh {@code Visitor} and returns the rewritten expression together with the {@code InputChannels} built from what the visitor collected. NOTE(review): this diff view shows two versions of the accept/InputChannels lines; the {@code accept(visitor, null)} pair looks like the pre-change code, while the {@code accept(visitor, true)} pair starts the traversal as unconditionally evaluated and also passes the eagerly loaded channels. */ public static Result rewritePageFieldsToInputParameters(RowExpression expression) | ||
| { | ||
| Visitor visitor = new Visitor(); | ||
| RowExpression rewrittenProjection = expression.accept(visitor, null); | ||
| InputChannels inputChannels = new InputChannels(visitor.getInputChannels()); | ||
| RowExpression rewrittenProjection = expression.accept(visitor, true); | ||
| InputChannels inputChannels = new InputChannels(visitor.getInputChannels(), visitor.getEagerlyLoadedChannels()); | ||
| return new Result(rewrittenProjection, inputChannels); | ||
| } | ||
|
|
||
| private static class Visitor | ||
| implements RowExpressionVisitor<RowExpression, Void> | ||
| implements RowExpressionVisitor<RowExpression, Boolean> | ||
| { | ||
| private final Map<Integer, Integer> fieldToParameter = new HashMap<>(); | ||
| private final List<Integer> inputChannels = new ArrayList<>(); | ||
|
sopel39 marked this conversation as resolved.
Outdated
|
||
| private final Set<Integer> eagerlyLoadedChannels = new HashSet<>(); | ||
| private int nextParameter; | ||
|
|
||
| /** Returns an immutable copy of the {@code inputChannels} list. */ public List<Integer> getInputChannels() | ||
| { | ||
| return ImmutableList.copyOf(inputChannels); | ||
| } | ||
|
|
||
| /** Returns an immutable copy of {@code eagerlyLoadedChannels}. The backing field is a {@code HashSet}, so the order of the returned list is unspecified. */ public List<Integer> getEagerlyLoadedChannels() | ||
| { | ||
| return ImmutableList.copyOf(eagerlyLoadedChannels); | ||
| } | ||
|
|
||
| /** Replaces an input reference with a {@code field(...)} reference to the parameter index returned by {@code getParameterForField}, keeping the reference's type. When {@code unconditionallyEvaluated} is true, the referenced field is also recorded in {@code eagerlyLoadedChannels}. NOTE(review): the signature taking {@code Void context} appears to be the pre-change line of this diff. */ @Override | ||
| public RowExpression visitInputReference(InputReferenceExpression reference, Void context) | ||
| public RowExpression visitInputReference(InputReferenceExpression reference, Boolean unconditionallyEvaluated) | ||
| { | ||
| if (unconditionallyEvaluated) { | ||
| eagerlyLoadedChannels.add(reference.getField()); | ||
| } | ||
| int parameter = getParameterForField(reference); | ||
| return field(parameter, reference.getType()); | ||
| } | ||
|
|
@@ -75,44 +87,73 @@ private Integer getParameterForField(InputReferenceExpression reference) | |
| } | ||
|
|
||
| @Override | ||
| public RowExpression visitCall(CallExpression call, Void context) | ||
| public RowExpression visitCall(CallExpression call, Boolean unconditionallyEvaluated) | ||
| { | ||
| boolean containsLambdaExpression = call.getArguments().stream().anyMatch(LambdaDefinitionExpression.class::isInstance); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't fully understand why call expression has special handling of lambdas? Shouldn't lambdas be fully covered via
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When the visitor is traversing lambda expression, the leaf nodes always seem to be ConstantExpression or VariableReferenceExpression. So setting any value for
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will leave this one for @martint to look at.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You would need to link those references to the corresponding arguments of the call to the higher-order function. For example, in a hypothetical
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @martint in the current code my intent was to fall back to the existing behaviour for lambda expressions. The motivating scenario for this PR was to improve efficiency for simple filters and projections (e.g. tpch/q01). Can we skip optimising this scenario, or would you still consider this a blocker?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we can skip it for now. But add a TODO so we know the implementation is not yet complete.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added a TODO here now |
||
| return new CallExpression( | ||
| call.getResolvedFunction(), | ||
| call.getArguments().stream() | ||
| .map(expression -> expression.accept(this, context)) | ||
| // Lambda expressions may use only some of their input references, e.g. transform(elements, x -> 1) | ||
| // TODO: Currently we fallback to assuming that all the arguments are conditionally evaluated when | ||
| // a lambda expression is encountered for the sake of simplicity. | ||
| .map(expression -> expression.accept(this, unconditionallyEvaluated && !containsLambdaExpression)) | ||
| .collect(toImmutableList())); | ||
| } | ||
|
|
||
| /** Rebuilds a special form with rewritten arguments, deciding per form which arguments inherit {@code unconditionallyEvaluated}. For IF, SWITCH, BETWEEN, AND, OR and COALESCE only the first argument inherits the flag; every later argument is visited with {@code false}. For BIND, IN, WHEN, IS_NULL, NULL_IF, DEREFERENCE and ROW_CONSTRUCTOR every argument inherits the flag. Any other form throws {@code IllegalArgumentException}. NOTE(review): the {@code Void context} signature and the first stream-based {@code return} appear to be the pre-change lines of this diff. */ @Override | ||
| public RowExpression visitSpecialForm(SpecialForm specialForm, Void context) | ||
| public RowExpression visitSpecialForm(SpecialForm specialForm, Boolean unconditionallyEvaluated) | ||
| { | ||
| return new SpecialForm( | ||
| specialForm.getForm(), | ||
| specialForm.getType(), | ||
| specialForm.getArguments().stream() | ||
| .map(expression -> expression.accept(this, context)) | ||
| .collect(toImmutableList()), | ||
| specialForm.getFunctionDependencies()); | ||
| switch (specialForm.getForm()) { | ||
| case IF: | ||
| case SWITCH: | ||
| case BETWEEN: | ||
| case AND: | ||
| case OR: | ||
| case COALESCE: | ||
| List<RowExpression> arguments = specialForm.getArguments(); | ||
| return new SpecialForm( | ||
| specialForm.getForm(), | ||
| specialForm.getType(), | ||
| IntStream.range(0, arguments.size()).boxed() | ||
| // All the arguments after the first one are assumed to be conditionally evaluated | ||
| .map(index -> arguments.get(index).accept(this, index == 0 && unconditionallyEvaluated)) | ||
| .collect(toImmutableList()), | ||
| specialForm.getFunctionDependencies()); | ||
| case BIND: | ||
| case IN: | ||
| case WHEN: | ||
| case IS_NULL: | ||
| case NULL_IF: | ||
| case DEREFERENCE: | ||
| case ROW_CONSTRUCTOR: | ||
| return new SpecialForm( | ||
| specialForm.getForm(), | ||
| specialForm.getType(), | ||
| specialForm.getArguments().stream() | ||
| .map(expression -> expression.accept(this, unconditionallyEvaluated)) | ||
| .collect(toImmutableList()), | ||
| specialForm.getFunctionDependencies()); | ||
| } | ||
| throw new IllegalArgumentException("Unsupported special form " + specialForm.getForm()); | ||
| } | ||
|
|
||
| /** Returns the constant unchanged; the context argument is not used. NOTE(review): the {@code Void context} signature appears to be the pre-change line of this diff. */ @Override | ||
| public RowExpression visitConstant(ConstantExpression literal, Void context) | ||
| public RowExpression visitConstant(ConstantExpression literal, Boolean unconditionallyEvaluated) | ||
| { | ||
| return literal; | ||
| } | ||
|
|
||
| @Override | ||
| public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context) | ||
| public RowExpression visitLambda(LambdaDefinitionExpression lambda, Boolean unconditionallyEvaluated) | ||
| { | ||
| return new LambdaDefinitionExpression( | ||
| lambda.getArgumentTypes(), | ||
| lambda.getArguments(), | ||
| lambda.getBody().accept(this, context)); | ||
| lambda.getBody().accept(this, unconditionallyEvaluated)); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This really depends on the function being applied. It's possible that not all of the references in the lambda expression are used.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, I've updated
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@martint could you put an example here? Should we just assume that lambda body is |
||
| } | ||
|
|
||
| /** Returns the variable reference unchanged; the context argument is not used. NOTE(review): the {@code Void context} signature appears to be the pre-change line of this diff. */ @Override | ||
| public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context) | ||
| public RowExpression visitVariableReference(VariableReferenceExpression reference, Boolean unconditionallyEvaluated) | ||
| { | ||
| return reference; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| /* | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package io.trino.operator.project; | ||
|
|
||
| import com.google.common.collect.ImmutableSet; | ||
| import io.trino.Session; | ||
| import io.trino.metadata.Metadata; | ||
| import io.trino.spi.Page; | ||
| import io.trino.spi.block.Block; | ||
| import io.trino.spi.block.LazyBlock; | ||
| import io.trino.spi.type.ArrayType; | ||
| import io.trino.spi.type.Type; | ||
| import io.trino.sql.planner.Symbol; | ||
| import io.trino.sql.planner.TypeAnalyzer; | ||
| import io.trino.sql.planner.TypeProvider; | ||
| import io.trino.sql.relational.RowExpression; | ||
| import io.trino.sql.relational.SqlToRowExpressionTranslator; | ||
| import io.trino.sql.tree.Expression; | ||
| import io.trino.testing.TestingSession; | ||
| import io.trino.transaction.TransactionId; | ||
| import org.testng.annotations.Test; | ||
|
|
||
| import java.util.HashMap; | ||
| import java.util.LinkedList; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Set; | ||
| import java.util.stream.IntStream; | ||
|
|
||
| import static com.google.common.collect.ImmutableSet.toImmutableSet; | ||
| import static io.trino.block.BlockAssertions.createLongSequenceBlock; | ||
| import static io.trino.metadata.MetadataManager.createTestMetadataManager; | ||
| import static io.trino.operator.project.PageFieldsToInputParametersRewriter.Result; | ||
| import static io.trino.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters; | ||
| import static io.trino.spi.type.BigintType.BIGINT; | ||
| import static io.trino.sql.ExpressionTestUtils.createExpression; | ||
| import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; | ||
| import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; | ||
| import static org.assertj.core.api.Assertions.assertThat; | ||
|
|
||
| public class TestPageFieldsToInputParametersRewriter | ||
| { | ||
| private static final Metadata METADATA = createTestMetadataManager(); | ||
| private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT); | ||
| private static final Session TEST_SESSION = TestingSession.testSessionBuilder() | ||
| .setTransactionId(TransactionId.create()) | ||
| .build(); | ||
|
|
||
| /** Checks, for a range of expressions, which lazy input blocks are loaded once the rewritten expression's {@code InputChannels} have been applied to a page. */ @Test | ||
| public void testEagerLoading() | ||
| { | ||
| RowExpressionBuilder builder = RowExpressionBuilder.create() | ||
| .addSymbol("bigint0", BIGINT) | ||
| .addSymbol("bigint1", BIGINT); | ||
| /* Expressions reading only bigint0: the two-argument overload expects channel 0 to be loaded. */ verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 + 5"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("CAST((bigint0 * 10) AS INT)"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE((bigint0 % 2), bigint0)"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 IN (1, 2, 3)"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 > 0"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 + 1 = 0"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 BETWEEN 1 AND 10"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint0 ELSE null END"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("CASE bigint0 WHEN 1 THEN 1 ELSE -bigint0 END"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("IF(bigint0 >= 150000, 0, 1)"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("IF(bigint0 >= 150000, bigint0, 0)"), 1); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE(0, bigint0) + bigint0"), 1); | ||
|
|
||
| /* Expressions reading both columns: the two-argument calls expect channels 0 and 1 to be loaded, the three-argument calls expect only channel 0. */ verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 + (2 * bigint1)"), 2); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("NULLIF(bigint0, bigint1)"), 2); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE(CEIL(bigint0 / bigint1), 0)"), 2); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("CASE WHEN (bigint0 > bigint1) THEN 1 ELSE 0 END"), 2); | ||
| verifyEagerlyLoadedColumns( | ||
| builder.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint1 ELSE 0 END"), 2, ImmutableSet.of(0)); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("COALESCE(ROUND(bigint0), bigint1)"), 2, ImmutableSet.of(0)); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 > 0 AND bigint1 > 0"), 2, ImmutableSet.of(0)); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 > 0 OR bigint1 > 0"), 2, ImmutableSet.of(0)); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("bigint0 BETWEEN 0 AND bigint1"), 2, ImmutableSet.of(0)); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("IF(bigint1 >= 150000, 0, bigint0)"), 2, ImmutableSet.of(0)); | ||
|
|
||
| /* Calls that take a lambda argument: no channel is expected to be loaded. */ builder = RowExpressionBuilder.create() | ||
| .addSymbol("array_bigint0", new ArrayType(BIGINT)) | ||
| .addSymbol("array_bigint1", new ArrayType(BIGINT)); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("TRANSFORM(array_bigint0, x -> 1)"), 1, ImmutableSet.of()); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("TRANSFORM(array_bigint0, x -> 2 * x)"), 1, ImmutableSet.of()); | ||
| verifyEagerlyLoadedColumns(builder.buildExpression("ZIP_WITH(array_bigint0, array_bigint1, (x, y) -> 2 * x)"), 2, ImmutableSet.of()); | ||
| } | ||
|
|
||
| /** Convenience overload that expects every channel in {@code [0, columnCount)} to be loaded. */ private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount) | ||
| { | ||
| verifyEagerlyLoadedColumns(rowExpression, columnCount, IntStream.range(0, columnCount).boxed().collect(toImmutableSet())); | ||
| } | ||
|
|
||
| /** Rewrites {@code rowExpression}, builds a page of {@code columnCount} lazily wrapped long-sequence blocks, passes it through the rewrite result's {@code InputChannels}, and asserts that each channel of the returned page is loaded exactly when it is contained in {@code eagerlyLoadedChannels}. */ private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount, Set<Integer> eagerlyLoadedChannels) | ||
| { | ||
| Result result = rewritePageFieldsToInputParameters(rowExpression); | ||
| Block[] blocks = new Block[columnCount]; | ||
| for (int channel = 0; channel < columnCount; channel++) { | ||
| /* Every input block starts out as a LazyBlock so that loading can be observed afterwards. */ blocks[channel] = lazyWrapper(createLongSequenceBlock(0, 100)); | ||
| } | ||
| Page page = result.getInputChannels().getInputChannels(new Page(blocks)); | ||
| for (int channel = 0; channel < columnCount; channel++) { | ||
| /* isLoaded() must match membership in the expected set, so unexpected loads fail the test as well as missing ones. */ assertThat(page.getBlock(channel).isLoaded()).isEqualTo(eagerlyLoadedChannels.contains(channel)); | ||
| } | ||
| } | ||
|
|
||
| /** Wraps {@code block} in a {@code LazyBlock} with the same position count whose loader returns {@code block.getLoadedBlock()}. */ private static LazyBlock lazyWrapper(Block block) | ||
| { | ||
| return new LazyBlock(block.getPositionCount(), block::getLoadedBlock); | ||
| } | ||
|
|
||
| /** Test helper that registers named, typed symbols and translates SQL expression strings over those symbols into {@code RowExpression}s. */ private static class RowExpressionBuilder | ||
| { | ||
| /* Type of each registered symbol; used for parsing and type analysis. */ private final Map<Symbol, Type> symbolTypes = new HashMap<>(); | ||
| /* Channel index assigned to each symbol, in registration order. */ private final Map<Symbol, Integer> sourceLayout = new HashMap<>(); | ||
| /* Types in registration order; within this class only its size is read, to pick the next channel index. */ private final List<Type> types = new LinkedList<>(); | ||
|
|
||
| /** Creates an empty builder. */ private static RowExpressionBuilder create() | ||
| { | ||
| return new RowExpressionBuilder(); | ||
| } | ||
|
|
||
| /** Registers a symbol with the given name and type, assigns it the next channel index, and returns this builder for chaining. */ private RowExpressionBuilder addSymbol(String name, Type type) | ||
| { | ||
| Symbol symbol = new Symbol(name); | ||
| symbolTypes.put(symbol, type); | ||
| /* The index must be taken before the type is appended so the first symbol maps to channel 0. */ sourceLayout.put(symbol, types.size()); | ||
| types.add(type); | ||
| return this; | ||
| } | ||
|
|
||
| /** Parses {@code value} against the registered symbol types and translates it to a {@code RowExpression} using {@code sourceLayout} to resolve symbols to channels. */ private RowExpression buildExpression(String value) | ||
| { | ||
| Expression expression = createExpression(value, PLANNER_CONTEXT, TypeProvider.copyOf(symbolTypes)); | ||
|
|
||
| return SqlToRowExpressionTranslator.translate( | ||
| expression, | ||
| TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), | ||
| sourceLayout, | ||
| METADATA, | ||
| TEST_SESSION, | ||
| true); | ||
| } | ||
| } | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.