Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
Expand All @@ -22,6 +25,8 @@
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.JoinNode;
Expand All @@ -34,18 +39,98 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode;
import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject;
import static com.facebook.presto.sql.planner.PlannerUtils.orNullHashCode;
import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions;
import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.callOperator;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
* This optimizer filter the right side of a join with the unique join keys on the left side of the join. When the join key is wide or
* there are multiple join keys, we are to do filter on the hash instead of using the keys.
* It will convert plan from
* <pre>
* - InnerJoin
* leftKey = rightKey
* - scan l
* - scan r
* </pre>
* into
* <pre>
* - InnerJoin
* leftKey = rightKey
* - scan l
* - semiJoin
* r.rightKey in l.leftKey
* - scan r
* - distinct aggregation
* group by leftKey
* - scan l
* </pre>
* And for join with varchar type
* <pre>
* - InnerJoin
* leftKey (varchar) = rightKey (varchar)
* - scan l
* - scan r
* </pre>
* into
* <pre>
* - InnerJoin
* leftKey (varchar) = rightKey (varchar)
* - scan l
* - semiJoin
* r.rightKeyHash in l.leftKeyHash
* - project
* r.rightKeyHash = xx_hash64(r.rightKey)
* - scan r
* - distinct aggregation
* group by leftKeyHash
* - project
* l.leftKeyHash = xx_hash64(l.leftKey)
* - scan l
* </pre>
* And for join with multiple keys
* <pre>
* - InnerJoin
* leftKey1 = rightKey1 and leftKey2 = rightKey2
* - scan l
* - scan r
* </pre>
* into
* <pre>
* - InnerJoin
* leftKey1 = rightKey1 and leftKey2 = rightKey2
* - scan l
* - semiJoin
* r.rightKeysHash in l.leftKeysHash
* - project
* r.rightKeysHash = combine_hash(xx_hash64(rightKey1), xx_hash64(rightKey2))
* - scan r
* - distinct aggregation
* group by leftKeysHash
* - project
* l.leftKeysHash = combine_hash(xx_hash64(leftKey1), xx_hash64(leftKey2))
* - scan l
* </pre>
*/

public class JoinPrefilter
implements PlanOptimizer
{
Expand Down Expand Up @@ -73,7 +158,7 @@ public boolean isEnabled(Session session)
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
if (isEnabled(session)) {
Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator);
Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator, metadata.getFunctionAndTypeManager());
PlanNode rewritten = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
return PlanOptimizerResult.optimizerResult(rewritten, rewriter.isPlanChanged());
}
Expand All @@ -88,14 +173,16 @@ private static class Rewriter
private final Metadata metadata;
private final PlanNodeIdAllocator idAllocator;
private final VariableAllocator variableAllocator;
private final FunctionAndTypeManager functionAndTypeManager;
private boolean planChanged;

private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager)
{
this.session = requireNonNull(session, "session is null");
this.metadata = requireNonNull(metadata, "functionAndTypeManager is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.variableAllocator = requireNonNull(variableAllocator, "idAllocator is null");
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
}

@Override
Expand All @@ -108,21 +195,37 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
PlanNode rewrittenRight = rewriteWith(this, right);
List<EquiJoinClause> equiJoinClause = node.getCriteria();

// We apply this for only left and inner join and the left side of the join is a simple scan and the join is on one key
if (equiJoinClause.size() == 1 &&
(node.getType() == LEFT || node.getType() == INNER) &&
isScanFilterProject(rewrittenLeft)) {
VariableReferenceExpression leftKey = equiJoinClause.stream().map(x -> x.getLeft()).findFirst().get();
VariableReferenceExpression rightKey = equiJoinClause.stream().map(x -> x.getRight()).findFirst().get();
// We apply this for only left and inner join and the left side of the join is a simple scan
if ((node.getType() == LEFT || node.getType() == INNER) && isScanFilterProject(rewrittenLeft) && !node.getCriteria().isEmpty()) {
List<VariableReferenceExpression> leftKeyList = equiJoinClause.stream().map(EquiJoinClause::getLeft).collect(toImmutableList());
List<VariableReferenceExpression> rightKeyList = equiJoinClause.stream().map(EquiJoinClause::getRight).collect(toImmutableList());
checkState(IntStream.range(0, leftKeyList.size()).boxed().allMatch(i -> leftKeyList.get(i).getType().equals(rightKeyList.get(i).getType())));

boolean hashJoinKey = leftKeyList.size() > 1 || (leftKeyList.get(0).getType().equals(VARCHAR) || leftKeyList.get(0).getType() instanceof VarcharType);

// First create a SELECT DISTINCT leftKey FROM left
Map<VariableReferenceExpression, VariableReferenceExpression> leftVarMap = new HashMap();
PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, ImmutableList.of(leftKey), leftVarMap);
PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, ImmutableList.of(leftVarMap.get(leftKey)), ImmutableList.of());
PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, leftKeyList, leftVarMap);
ImmutableList.Builder<RowExpression> expressionsToProject = ImmutableList.builder();
if (hashJoinKey) {
RowExpression hashExpression = getVariableHash(leftKeyList);
expressionsToProject.add(hashExpression);
}
else {
expressionsToProject.add(leftVarMap.get(leftKeyList.get(0)));
}
PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, expressionsToProject.build(), ImmutableList.of());

VariableReferenceExpression rightKeyToFilter = rightKeyList.get(0);
if (hashJoinKey) {
RowExpression hashExpression = getVariableHash(rightKeyList);
rightKeyToFilter = variableAllocator.newVariable(hashExpression);
rewrittenRight = addProjections(rewrittenRight, idAllocator, ImmutableMap.of(rightKeyToFilter, hashExpression));
}

// DISTINCT on the leftkey
// DISTINCT on the leftkey or hash if wide column
PlanNode filteringSource = new AggregationNode(
leftKey.getSourceLocation(),
node.getLeft().getSourceLocation(),
idAllocator.getNextId(),
projectNode,
ImmutableMap.of(),
Expand All @@ -139,12 +242,12 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
// Now we add a semijoin as the right side
VariableReferenceExpression semiJoinOutput = variableAllocator.newVariable("semiJoinOutput", BOOLEAN);
SemiJoinNode semiJoinNode = new SemiJoinNode(
rightKey.getSourceLocation(),
node.getRight().getSourceLocation(),
idAllocator.getNextId(),
node.getStatsEquivalentPlanNode(),
rewrittenRight,
filteringSource,
rightKey,
rightKeyToFilter,
filteringSource.getOutputVariables().get(0),
semiJoinOutput,
Optional.empty(),
Expand All @@ -153,6 +256,9 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
ImmutableMap.of());

rewrittenRight = new FilterNode(semiJoinNode.getSourceLocation(), idAllocator.getNextId(), semiJoinNode, semiJoinOutput);
if (rewrittenRight.getOutputVariables().size() > node.getRight().getOutputVariables().size()) {
rewrittenRight = restrictOutput(rewrittenRight, idAllocator, node.getRight().getOutputVariables());
}
}

if (rewrittenLeft != node.getLeft() || rewrittenRight != node.getRight()) {
Expand All @@ -167,5 +273,19 @@ public boolean isPlanChanged()
{
return planChanged;
}

private RowExpression getVariableHash(List<VariableReferenceExpression> inputVariables)
{
List<CallExpression> hashExpressionList = inputVariables.stream().map(keyVariable ->
callOperator(functionAndTypeManager.getFunctionAndTypeResolver(), OperatorType.XX_HASH_64, BIGINT, keyVariable)).collect(toImmutableList());
RowExpression hashExpression = hashExpressionList.get(0);
if (hashExpressionList.size() > 1) {
hashExpression = orNullHashCode(hashExpression);
for (int i = 1; i < hashExpressionList.size(); ++i) {
hashExpression = call(functionAndTypeManager, "combine_hash", BIGINT, hashExpression, orNullHashCode(hashExpressionList.get(i)));
}
}
return hashExpression;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7768,21 +7768,62 @@ public void testLambdaInAggregation()
@Test
public void testJoinPrefilter()
{
// Orig
String testQuery = "SELECT 1 from region join nation using(regionkey)";
MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
result = computeActual(testQuery);
assertEquals(result.getRowCount(), 25);

// With feature
Session session = Session.builder(getSession())
.setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
.build();
result = computeActual(session, "explain(type distributed) " + testQuery);
assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
result = computeActual(session, testQuery);
assertEquals(result.getRowCount(), 25);
{
// Orig
String testQuery = "SELECT 1 from region join nation using(regionkey)";
MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
result = computeActual(testQuery);
assertEquals(result.getRowCount(), 25);

// With feature
Session session = Session.builder(getSession())
.setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
.build();
result = computeActual(session, "explain(type distributed) " + testQuery);
assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
result = computeActual(session, testQuery);
assertEquals(result.getRowCount(), 25);
}

{
// Orig
@Language("SQL") String testQuery = "SELECT 1 from region r join nation n on cast(r.regionkey as varchar) = cast(n.regionkey as varchar)";
MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
result = computeActual(testQuery);
assertEquals(result.getRowCount(), 25);

// With feature
Session session = Session.builder(getSession())
.setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
.setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, String.valueOf(false))
.build();
result = computeActual(session, "explain(type distributed) " + testQuery);
assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1);
result = computeActual(session, testQuery);
assertEquals(result.getRowCount(), 25);
}

{
// Orig
String testQuery = "SELECT 1 from lineitem l join orders o on l.orderkey = o.orderkey and l.suppkey = o.custkey";
MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
result = computeActual(testQuery);
assertEquals(result.getRowCount(), 37);

// With feature
Session session = Session.builder(getSession())
.setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
.build();
result = computeActual(session, "explain(type distributed) " + testQuery);
assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1);
result = computeActual(session, testQuery);
assertEquals(result.getRowCount(), 37);
}
}

@Test
Expand Down