Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.operator.project;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import io.trino.spi.Page;

Expand All @@ -25,15 +26,23 @@
public class InputChannels
{
private final int[] inputChannels;
private final int[] eagerlyLoadedChannels;

/**
 * Creates a channel selection from an array of page column indexes.
 * The array is copied, and no channel is marked for eager loading.
 *
 * @param inputChannels column indexes to select from a page
 */
public InputChannels(int... inputChannels)
{
    // Only the channel list was supplied, so the eager-loading set is empty
    this.eagerlyLoadedChannels = new int[0];
    // Defensive copy: later changes to the caller's array must not leak into this object
    this.inputChannels = inputChannels.clone();
}

/**
 * Creates a channel selection from a list of page column indexes by delegating to
 * the two-list constructor with an empty list of eagerly loaded channels.
 *
 * @param inputChannels column indexes to select from a page
 */
public InputChannels(List<Integer> inputChannels)
{
    this(inputChannels, ImmutableList.<Integer>of());
}

/**
 * Creates a channel selection together with the channels whose blocks are loaded
 * up front when a page is passed to {@code getInputChannels(Page)}.
 *
 * @param inputChannels column indexes to select from a page
 * @param eagerlyLoadedChannels column indexes to load eagerly
 */
public InputChannels(List<Integer> inputChannels, List<Integer> eagerlyLoadedChannels)
{
    // Unbox both lists into primitive arrays; a null element fails with NullPointerException
    this.inputChannels = Ints.toArray(inputChannels);
    this.eagerlyLoadedChannels = Ints.toArray(eagerlyLoadedChannels);
}

public int size()
Expand All @@ -48,7 +57,7 @@ public List<Integer> getInputChannels()

public Page getInputChannels(Page page)
{
return page.getColumns(inputChannels);
return page.getLoadedPage(inputChannels, eagerlyLoadedChannels);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.relational.Expressions.field;
Expand All @@ -42,26 +45,35 @@ private PageFieldsToInputParametersRewriter() {}
public static Result rewritePageFieldsToInputParameters(RowExpression expression)
{
Visitor visitor = new Visitor();
RowExpression rewrittenProjection = expression.accept(visitor, null);
InputChannels inputChannels = new InputChannels(visitor.getInputChannels());
RowExpression rewrittenProjection = expression.accept(visitor, true);
InputChannels inputChannels = new InputChannels(visitor.getInputChannels(), visitor.getEagerlyLoadedChannels());
return new Result(rewrittenProjection, inputChannels);
}

private static class Visitor
implements RowExpressionVisitor<RowExpression, Void>
implements RowExpressionVisitor<RowExpression, Boolean>
{
private final Map<Integer, Integer> fieldToParameter = new HashMap<>();
private final List<Integer> inputChannels = new ArrayList<>();
Comment thread
sopel39 marked this conversation as resolved.
Outdated
private final Set<Integer> eagerlyLoadedChannels = new HashSet<>();
private int nextParameter;

/**
 * Returns an immutable copy of this visitor's {@code inputChannels} list,
 * preserving the list's order.
 */
public List<Integer> getInputChannels()
{
    List<Integer> snapshot = ImmutableList.copyOf(inputChannels);
    return snapshot;
}

/**
 * Returns an immutable copy of the fields recorded as eagerly loaded.
 * The backing collection is a {@code HashSet}, so the order of the returned
 * list is unspecified.
 */
public List<Integer> getEagerlyLoadedChannels()
{
    List<Integer> snapshot = ImmutableList.copyOf(eagerlyLoadedChannels);
    return snapshot;
}

@Override
public RowExpression visitInputReference(InputReferenceExpression reference, Void context)
public RowExpression visitInputReference(InputReferenceExpression reference, Boolean unconditionallyEvaluated)
{
if (unconditionallyEvaluated) {
eagerlyLoadedChannels.add(reference.getField());
}
int parameter = getParameterForField(reference);
return field(parameter, reference.getType());
}
Expand All @@ -75,44 +87,73 @@ private Integer getParameterForField(InputReferenceExpression reference)
}

@Override
public RowExpression visitCall(CallExpression call, Void context)
public RowExpression visitCall(CallExpression call, Boolean unconditionallyEvaluated)
{
boolean containsLambdaExpression = call.getArguments().stream().anyMatch(LambdaDefinitionExpression.class::isInstance);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand why call expression has special handling of lambdas? Shouldn't lambdas be fully covered via visitLambda?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the visitor is traversing lambda expression, the leaf nodes always seem to be ConstantExpression or VariableReferenceExpression. So setting any value for unconditionallyEvaluated in visitLambda seems to be a no-op.
I think lambda expression is supposed to be always part of a function like transform, reduce etc. The InputReferenceExpression for the fields potentially accessed by the lambda are found in the initial arguments of CallExpression.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will leave this one for @martint to look at.
Could you search for the code that creates a CallExpression for lambdas and see what the arguments are? According to the ANTLR grammar, lambdas are primary expressions, so I'm not sure where the CallExpression comes from.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the leaf nodes always seem to be ConstantExpression or VariableReferenceExpression. So setting any value for unconditionallyEvaluated in visitLambda seems to be a no-op.

You would need to link those references to the corresponding arguments of the call to the higher-order function.

For example, in a hypothetical apply2 function: apply2(a, b, (x, y) -> x AND y), b is conditionally loaded, but a isn't due to the conditional nature of x AND y.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@martint in the current code my intent was to fall back to the existing behaviour for lambda expressions. The motivating scenario for this PR was to improve efficiency for simple filters and projections (e.g. tpch/q01). Can we skip optimising this scenario, or would you still consider this a blocker?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can skip it for now. But add a TODO so we know the implementation is not yet complete.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a TODO here now

return new CallExpression(
call.getResolvedFunction(),
call.getArguments().stream()
.map(expression -> expression.accept(this, context))
// Lambda expressions may use only some of their input references, e.g. transform(elements, x -> 1)
// TODO: Currently we fallback to assuming that all the arguments are conditionally evaluated when
// a lambda expression is encountered for the sake of simplicity.
.map(expression -> expression.accept(this, unconditionallyEvaluated && !containsLambdaExpression))
.collect(toImmutableList()));
}

@Override
public RowExpression visitSpecialForm(SpecialForm specialForm, Void context)
public RowExpression visitSpecialForm(SpecialForm specialForm, Boolean unconditionallyEvaluated)
{
return new SpecialForm(
specialForm.getForm(),
specialForm.getType(),
specialForm.getArguments().stream()
.map(expression -> expression.accept(this, context))
.collect(toImmutableList()),
specialForm.getFunctionDependencies());
switch (specialForm.getForm()) {
case IF:
case SWITCH:
case BETWEEN:
case AND:
case OR:
case COALESCE:
List<RowExpression> arguments = specialForm.getArguments();
return new SpecialForm(
specialForm.getForm(),
specialForm.getType(),
IntStream.range(0, arguments.size()).boxed()
// All the arguments after the first one are assumed to be conditionally evaluated
.map(index -> arguments.get(index).accept(this, index == 0 && unconditionallyEvaluated))
.collect(toImmutableList()),
specialForm.getFunctionDependencies());
case BIND:
case IN:
case WHEN:
case IS_NULL:
case NULL_IF:
case DEREFERENCE:
case ROW_CONSTRUCTOR:
return new SpecialForm(
specialForm.getForm(),
specialForm.getType(),
specialForm.getArguments().stream()
.map(expression -> expression.accept(this, unconditionallyEvaluated))
.collect(toImmutableList()),
specialForm.getFunctionDependencies());
}
throw new IllegalArgumentException("Unsupported special form " + specialForm.getForm());
}

@Override
public RowExpression visitConstant(ConstantExpression literal, Void context)
public RowExpression visitConstant(ConstantExpression literal, Boolean unconditionallyEvaluated)
{
return literal;
}

@Override
public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
public RowExpression visitLambda(LambdaDefinitionExpression lambda, Boolean unconditionallyEvaluated)
{
return new LambdaDefinitionExpression(
lambda.getArgumentTypes(),
lambda.getArguments(),
lambda.getBody().accept(this, context));
lambda.getBody().accept(this, unconditionallyEvaluated));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This really depends on the function being applied. It's possible that not all of the references in the lambda expression are used.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I've updated visitCall to now assume that all the inputs to lambda expression can be conditionally evaluated (i.e. keep existing behaviour of all blocks being lazily loaded).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This really depends on the function being applied. It's possible that not all of the references in the lambda expression are used.

@martint could you put an example here?

Should we just assume that lambda body is unconditionallyEvaluated=false and add a TODO?

}

@Override
public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
public RowExpression visitVariableReference(VariableReferenceExpression reference, Boolean unconditionallyEvaluated)
{
return reference;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.project;

import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.LazyBlock;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SqlToRowExpressionTranslator;
import io.trino.sql.tree.Expression;
import io.trino.testing.TestingSession;
import io.trino.transaction.TransactionId;
import org.testng.annotations.Test;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.block.BlockAssertions.createLongSequenceBlock;
import static io.trino.metadata.MetadataManager.createTestMetadataManager;
import static io.trino.operator.project.PageFieldsToInputParametersRewriter.Result;
import static io.trino.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.ExpressionTestUtils.createExpression;
import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT;
import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Checks which input channels end up loaded after a page is passed through the
 * {@code InputChannels} produced by {@link PageFieldsToInputParametersRewriter}.
 * Every input column starts out as an unloaded {@link LazyBlock}; each case states
 * how many columns the expression reads and which of them must be loaded afterwards.
 */
public class TestPageFieldsToInputParametersRewriter
{
    private static final Metadata METADATA = createTestMetadataManager();
    private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT);
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder()
            .setTransactionId(TransactionId.create())
            .build();

    @Test
    public void testEagerLoading()
    {
        RowExpressionBuilder bigints = RowExpressionBuilder.create()
                .addSymbol("bigint0", BIGINT)
                .addSymbol("bigint1", BIGINT);

        // Single column: expected to be loaded in every case
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 + 5"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("CAST((bigint0 * 10) AS INT)"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("COALESCE((bigint0 % 2), bigint0)"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 IN (1, 2, 3)"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 > 0"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 + 1 = 0"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 BETWEEN 1 AND 10"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint0 ELSE null END"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("CASE bigint0 WHEN 1 THEN 1 ELSE -bigint0 END"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("IF(bigint0 >= 150000, 0, 1)"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("IF(bigint0 >= 150000, bigint0, 0)"), 1);
        verifyEagerlyLoadedColumns(bigints.buildExpression("COALESCE(0, bigint0) + bigint0"), 1);

        // Two columns, both expected to be loaded
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 + (2 * bigint1)"), 2);
        verifyEagerlyLoadedColumns(bigints.buildExpression("NULLIF(bigint0, bigint1)"), 2);
        verifyEagerlyLoadedColumns(bigints.buildExpression("COALESCE(CEIL(bigint0 / bigint1), 0)"), 2);
        verifyEagerlyLoadedColumns(bigints.buildExpression("CASE WHEN (bigint0 > bigint1) THEN 1 ELSE 0 END"), 2);

        // Two columns, only channel 0 expected to be loaded
        verifyEagerlyLoadedColumns(
                bigints.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint1 ELSE 0 END"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(bigints.buildExpression("COALESCE(ROUND(bigint0), bigint1)"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 > 0 AND bigint1 > 0"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 > 0 OR bigint1 > 0"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(bigints.buildExpression("bigint0 BETWEEN 0 AND bigint1"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(bigints.buildExpression("IF(bigint1 >= 150000, 0, bigint0)"), 2, ImmutableSet.of(0));

        // Calls taking a lambda: no channel expected to be loaded
        RowExpressionBuilder arrays = RowExpressionBuilder.create()
                .addSymbol("array_bigint0", new ArrayType(BIGINT))
                .addSymbol("array_bigint1", new ArrayType(BIGINT));
        verifyEagerlyLoadedColumns(arrays.buildExpression("TRANSFORM(array_bigint0, x -> 1)"), 1, ImmutableSet.of());
        verifyEagerlyLoadedColumns(arrays.buildExpression("TRANSFORM(array_bigint0, x -> 2 * x)"), 1, ImmutableSet.of());
        verifyEagerlyLoadedColumns(arrays.buildExpression("ZIP_WITH(array_bigint0, array_bigint1, (x, y) -> 2 * x)"), 2, ImmutableSet.of());
    }

    /**
     * Asserts that every one of the {@code columnCount} channels is loaded.
     */
    private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount)
    {
        Set<Integer> everyChannel = IntStream.range(0, columnCount)
                .boxed()
                .collect(toImmutableSet());
        verifyEagerlyLoadedColumns(rowExpression, columnCount, everyChannel);
    }

    /**
     * Rewrites the expression, feeds a page of {@code columnCount} lazy blocks through the
     * resulting input channels, and asserts that a channel's block is loaded exactly when
     * the channel is contained in {@code eagerlyLoadedChannels}.
     */
    private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount, Set<Integer> eagerlyLoadedChannels)
    {
        Result rewritten = rewritePageFieldsToInputParameters(rowExpression);

        // Every column starts out as an unloaded lazy block
        Block[] lazyBlocks = IntStream.range(0, columnCount)
                .mapToObj(channel -> lazyWrapper(createLongSequenceBlock(0, 100)))
                .toArray(Block[]::new);

        Page projected = rewritten.getInputChannels().getInputChannels(new Page(lazyBlocks));
        for (int channel = 0; channel < columnCount; channel++) {
            boolean expectedLoaded = eagerlyLoadedChannels.contains(channel);
            assertThat(projected.getBlock(channel).isLoaded()).isEqualTo(expectedLoaded);
        }
    }

    /**
     * Wraps a block in a {@link LazyBlock} of the same position count whose loader
     * returns the wrapped block's loaded form.
     */
    private static LazyBlock lazyWrapper(Block block)
    {
        int positions = block.getPositionCount();
        return new LazyBlock(positions, () -> block.getLoadedBlock());
    }

    /**
     * Collects symbols with their types and channel positions, then translates SQL
     * expression text over those symbols into a {@link RowExpression}.
     */
    private static class RowExpressionBuilder
    {
        private final Map<Symbol, Type> symbolTypes = new HashMap<>();
        private final Map<Symbol, Integer> sourceLayout = new HashMap<>();
        private final List<Type> types = new LinkedList<>();

        private static RowExpressionBuilder create()
        {
            return new RowExpressionBuilder();
        }

        /**
         * Registers a symbol; its channel is the number of symbols added before it.
         */
        private RowExpressionBuilder addSymbol(String name, Type type)
        {
            Symbol added = new Symbol(name);
            int channel = types.size();
            symbolTypes.put(added, type);
            sourceLayout.put(added, channel);
            types.add(type);
            return this;
        }

        /**
         * Parses {@code value} against the registered symbols and translates it to a row expression.
         */
        private RowExpression buildExpression(String value)
        {
            TypeProvider parseTypes = TypeProvider.copyOf(symbolTypes);
            Expression parsed = createExpression(value, PLANNER_CONTEXT, parseTypes);

            return SqlToRowExpressionTranslator.translate(
                    parsed,
                    TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), parsed),
                    sourceLayout,
                    METADATA,
                    TEST_SESSION,
                    true);
        }
    }
}
18 changes: 18 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/Page.java
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,24 @@ public Page getLoadedPage(int... columns)
return wrapBlocksWithoutCopy(positionCount, blocks);
}

public Page getLoadedPage(int[] columns, int[] eagerlyLoadedColumns)
{
requireNonNull(columns, "columns is null");

for (int column : eagerlyLoadedColumns) {
Comment thread
sopel39 marked this conversation as resolved.
Outdated
this.blocks[column] = this.blocks[column].getLoadedBlock();
}
if (retainedSizeInBytes != -1 && eagerlyLoadedColumns.length > 0) {
updateRetainedSize();
}
Block[] blocks = new Block[columns.length];
for (int i = 0; i < columns.length; i++) {
blocks[i] = this.blocks[columns[i]];
}

return wrapBlocksWithoutCopy(positionCount, blocks);
}

@Override
public String toString()
{
Expand Down
Loading