diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java
index 3a3f9fc2d2654..29d5f766c6c5b 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java
@@ -14,6 +14,9 @@
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
+import com.facebook.presto.common.function.OperatorType;
+import com.facebook.presto.common.type.VarcharType;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
@@ -22,6 +25,8 @@
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.JoinNode;
@@ -34,18 +39,98 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.stream.IntStream;
import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
+import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode;
import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject;
+import static com.facebook.presto.sql.planner.PlannerUtils.orNullHashCode;
import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions;
+import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
+import static com.facebook.presto.sql.relational.Expressions.call;
+import static com.facebook.presto.sql.relational.Expressions.callOperator;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
+/**
+ * This optimizer filters the right side of a join using the distinct join keys from the left side of the join. When the join key is wide
+ * (varchar) or there are multiple join keys, the filter is applied on a hash of the keys instead of on the keys themselves.
+ * It converts a plan of the form
+ *
+ * - InnerJoin
+ *     leftKey = rightKey
+ *     - scan l
+ *     - scan r
+ *
+ * into
+ *
+ * - InnerJoin
+ *     leftKey = rightKey
+ *     - scan l
+ *     - semiJoin
+ *         r.rightKey in l.leftKey
+ *         - scan r
+ *         - distinct aggregation
+ *             group by leftKey
+ *             - scan l
+ *
+ * And for join with varchar type
+ *
+ * - InnerJoin
+ *     leftKey (varchar) = rightKey (varchar)
+ *     - scan l
+ *     - scan r
+ *
+ * into
+ *
+ * - InnerJoin
+ *     leftKey (varchar) = rightKey (varchar)
+ *     - scan l
+ *     - semiJoin
+ *         r.rightKeyHash in l.leftKeyHash
+ *         - project
+ *             r.rightKeyHash = xx_hash64(r.rightKey)
+ *             - scan r
+ *         - distinct aggregation
+ *             group by leftKeyHash
+ *             - project
+ *                 l.leftKeyHash = xx_hash64(l.leftKey)
+ *                 - scan l
+ *
+ * And for join with multiple keys
+ *
+ * - InnerJoin
+ *     leftKey1 = rightKey1 and leftKey2 = rightKey2
+ *     - scan l
+ *     - scan r
+ *
+ * into
+ *
+ * - InnerJoin
+ *     leftKey1 = rightKey1 and leftKey2 = rightKey2
+ *     - scan l
+ *     - semiJoin
+ *         r.rightKeysHash in l.leftKeysHash
+ *         - project
+ *             r.rightKeysHash = combine_hash(xx_hash64(rightKey1), xx_hash64(rightKey2))
+ *             - scan r
+ *         - distinct aggregation
+ *             group by leftKeysHash
+ *             - project
+ *                 l.leftKeysHash = combine_hash(xx_hash64(leftKey1), xx_hash64(leftKey2))
+ *                 - scan l
+ *
+ */
+
public class JoinPrefilter
implements PlanOptimizer
{
@@ -73,7 +158,7 @@ public boolean isEnabled(Session session)
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
if (isEnabled(session)) {
- Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator);
+ Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator, metadata.getFunctionAndTypeManager());
PlanNode rewritten = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
return PlanOptimizerResult.optimizerResult(rewritten, rewriter.isPlanChanged());
}
@@ -88,14 +173,16 @@ private static class Rewriter
private final Metadata metadata;
private final PlanNodeIdAllocator idAllocator;
private final VariableAllocator variableAllocator;
+ private final FunctionAndTypeManager functionAndTypeManager;
private boolean planChanged;
- private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
+ private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager)
{
this.session = requireNonNull(session, "session is null");
- this.metadata = requireNonNull(metadata, "functionAndTypeManager is null");
+ this.metadata = requireNonNull(metadata, "metadata is null"); // each message names the argument it checks
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
- this.variableAllocator = requireNonNull(variableAllocator, "idAllocator is null");
+ this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
+ this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
}
@Override
@@ -108,21 +195,37 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context)
PlanNode rewrittenRight = rewriteWith(this, right);
List equiJoinClause = node.getCriteria();
- // We apply this for only left and inner join and the left side of the join is a simple scan and the join is on one key
- if (equiJoinClause.size() == 1 &&
- (node.getType() == LEFT || node.getType() == INNER) &&
- isScanFilterProject(rewrittenLeft)) {
- VariableReferenceExpression leftKey = equiJoinClause.stream().map(x -> x.getLeft()).findFirst().get();
- VariableReferenceExpression rightKey = equiJoinClause.stream().map(x -> x.getRight()).findFirst().get();
+ // Prefilter only LEFT and INNER joins that have at least one equi-join clause and whose left side passes isScanFilterProject
+ if ((node.getType() == LEFT || node.getType() == INNER) && isScanFilterProject(rewrittenLeft) && !equiJoinClause.isEmpty()) {
+ List<VariableReferenceExpression> leftKeyList = equiJoinClause.stream().map(EquiJoinClause::getLeft).collect(toImmutableList());
+ List<VariableReferenceExpression> rightKeyList = equiJoinClause.stream().map(EquiJoinClause::getRight).collect(toImmutableList());
+ checkState(IntStream.range(0, leftKeyList.size()).allMatch(i -> leftKeyList.get(i).getType().equals(rightKeyList.get(i).getType())), "join key types differ between left and right side: %s", equiJoinClause);
+ // Filter on a hash of the keys when there are several keys or the single key is a varchar
+ boolean hashJoinKey = leftKeyList.size() > 1 || (leftKeyList.get(0).getType().equals(VARCHAR) || leftKeyList.get(0).getType() instanceof VarcharType);
// First create a SELECT DISTINCT leftKey FROM left
Map leftVarMap = new HashMap();
- PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, ImmutableList.of(leftKey), leftVarMap);
- PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, ImmutableList.of(leftVarMap.get(leftKey)), ImmutableList.of());
+ PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, leftKeyList, leftVarMap);
+ ImmutableList.Builder<RowExpression> expressionsToProject = ImmutableList.builder();
+ if (hashJoinKey) {
+ RowExpression leftHashExpression = getVariableHash(leftKeyList); // NOTE(review): the non-hash branch maps the key through leftVarMap, this one does not - confirm the cloned plan exposes the original key variables
+ expressionsToProject.add(leftHashExpression);
+ }
+ else {
+ expressionsToProject.add(leftVarMap.get(leftKeyList.get(0)));
+ }
+ PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, expressionsToProject.build(), ImmutableList.of());
+ // The right side is filtered on its single key, or on a freshly projected hash of its keys
+ VariableReferenceExpression rightKeyToFilter = rightKeyList.get(0);
+ if (hashJoinKey) {
+ RowExpression rightHashExpression = getVariableHash(rightKeyList);
+ rightKeyToFilter = variableAllocator.newVariable(rightHashExpression);
+ rewrittenRight = addProjections(rewrittenRight, idAllocator, ImmutableMap.of(rightKeyToFilter, rightHashExpression));
+ }
- // DISTINCT on the leftkey
+ // DISTINCT on the leftkey or hash if wide column
PlanNode filteringSource = new AggregationNode(
- leftKey.getSourceLocation(),
+ node.getLeft().getSourceLocation(),
idAllocator.getNextId(),
projectNode,
ImmutableMap.of(),
@@ -139,12 +242,12 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context)
// Now we add a semijoin as the right side
VariableReferenceExpression semiJoinOutput = variableAllocator.newVariable("semiJoinOutput", BOOLEAN);
SemiJoinNode semiJoinNode = new SemiJoinNode(
- rightKey.getSourceLocation(),
+ node.getRight().getSourceLocation(),
idAllocator.getNextId(),
node.getStatsEquivalentPlanNode(),
rewrittenRight,
filteringSource,
- rightKey,
+ rightKeyToFilter,
filteringSource.getOutputVariables().get(0),
semiJoinOutput,
Optional.empty(),
@@ -153,6 +256,9 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context)
ImmutableMap.of());
rewrittenRight = new FilterNode(semiJoinNode.getSourceLocation(), idAllocator.getNextId(), semiJoinNode, semiJoinOutput);
+ if (rewrittenRight.getOutputVariables().size() > node.getRight().getOutputVariables().size()) {
+ rewrittenRight = restrictOutput(rewrittenRight, idAllocator, node.getRight().getOutputVariables());
+ }
}
if (rewrittenLeft != node.getLeft() || rewrittenRight != node.getRight()) {
@@ -167,5 +273,19 @@ public boolean isPlanChanged()
{
return planChanged;
}
+
+ /**
+ * Builds a BIGINT hash expression over the given join keys: XX_HASH_64 of the key when there is exactly one, otherwise the per-key hashes, each wrapped with orNullHashCode, folded left to right with combine_hash.
+ */
+ private RowExpression getVariableHash(List<VariableReferenceExpression> inputVariables)
+ {
+ checkState(!inputVariables.isEmpty(), "inputVariables is empty");
+ List<CallExpression> keyHashes = inputVariables.stream().map(keyVariable -> callOperator(functionAndTypeManager.getFunctionAndTypeResolver(), OperatorType.XX_HASH_64, BIGINT, keyVariable)).collect(toImmutableList());
+ RowExpression combinedHash = keyHashes.size() == 1 ? keyHashes.get(0) : orNullHashCode(keyHashes.get(0));
+ for (CallExpression keyHash : keyHashes.subList(1, keyHashes.size())) {
+ combinedHash = call(functionAndTypeManager, "combine_hash", BIGINT, combinedHash, orNullHashCode(keyHash));
+ }
+ return combinedHash;
+ }
}
}
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
index e043327d2f6f5..588ee48790ad8 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
@@ -7768,21 +7768,62 @@ public void testLambdaInAggregation()
@Test
public void testJoinPrefilter()
{
- // Orig
- String testQuery = "SELECT 1 from region join nation using(regionkey)";
- MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
- assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
- result = computeActual(testQuery);
- assertEquals(result.getRowCount(), 25);
-
- // With feature
- Session session = Session.builder(getSession())
- .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
- .build();
- result = computeActual(session, "explain(type distributed) " + testQuery);
- assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
- result = computeActual(session, testQuery);
- assertEquals(result.getRowCount(), 25);
+ {
+ // Single-key join, default session: the distributed plan must not mention SemiJoin; the query returns 25 rows
+ String testQuery = "SELECT 1 from region join nation using(regionkey)";
+ MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
+ assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
+ result = computeActual(testQuery);
+ assertEquals(result.getRowCount(), 25);
+
+ // JOIN_PREFILTER_BUILD_SIDE set to true: the plan must mention SemiJoin and the row count must stay 25
+ Session session = Session.builder(getSession())
+ .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
+ .build();
+ result = computeActual(session, "explain(type distributed) " + testQuery);
+ assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
+ result = computeActual(session, testQuery);
+ assertEquals(result.getRowCount(), 25);
+ }
+
+ {
+ // Join on keys cast to varchar, default session: no SemiJoin in the plan; the query returns 25 rows
+ @Language("SQL") String testQuery = "SELECT 1 from region r join nation n on cast(r.regionkey as varchar) = cast(n.regionkey as varchar)";
+ MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
+ assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
+ result = computeActual(testQuery);
+ assertEquals(result.getRowCount(), 25);
+
+ // Prefilter on and REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN off: the plan must mention both SemiJoin and XX_HASH_64
+ Session session = Session.builder(getSession())
+ .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
+ .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, String.valueOf(false))
+ .build();
+ result = computeActual(session, "explain(type distributed) " + testQuery);
+ assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
+ assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1);
+ result = computeActual(session, testQuery);
+ assertEquals(result.getRowCount(), 25);
+ }
+
+ {
+ // Two-key join, default session: no SemiJoin in the plan; the query returns 37 rows
+ String testQuery = "SELECT 1 from lineitem l join orders o on l.orderkey = o.orderkey and l.suppkey = o.custkey";
+ MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
+ assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
+ result = computeActual(testQuery);
+ assertEquals(result.getRowCount(), 37);
+
+ // Prefilter on: the plan must mention both SemiJoin and XX_HASH_64, and the row count must stay 37
+ Session session = Session.builder(getSession())
+ .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
+ .build();
+ result = computeActual(session, "explain(type distributed) " + testQuery);
+ assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1);
+ assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1);
+ result = computeActual(session, testQuery);
+ assertEquals(result.getRowCount(), 37);
+ }
}
@Test