diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java index 3a3f9fc2d2654..29d5f766c6c5b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java @@ -14,6 +14,9 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; @@ -22,6 +25,8 @@ import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -34,18 +39,98 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.IntStream; import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.JoinType.INNER; import static com.facebook.presto.spi.plan.JoinType.LEFT; +import static com.facebook.presto.sql.planner.PlannerUtils.addProjections; import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode; import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject; +import static com.facebook.presto.sql.planner.PlannerUtils.orNullHashCode; import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions; +import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.callOperator; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; +/** + * This optimizer filter the right side of a join with the unique join keys on the left side of the join. When the join key is wide or + * there are multiple join keys, we are to do filter on the hash instead of using the keys. + * It will convert plan from + *
+ *     - InnerJoin
+ *          leftKey = rightKey
+ *          - scan l
+ *          - scan r
+ * 
+ * into + *
+ *     - InnerJoin
+ *          leftKey = rightKey
+ *          - scan l
+ *          - semiJoin
+ *              r.rightKey in l.leftKey
+ *              - scan r
+ *              - distinct aggregation
+ *                  group by leftKey
+ *                  - scan l
+ * 
+ * And for join with varchar type + *
+ *     - InnerJoin
+ *          leftKey (varchar) = rightKey (varchar)
+ *          - scan l
+ *          - scan r
+ * 
+ * into + *
+ *     - InnerJoin
+ *          leftKey (varchar) = rightKey (varchar)
+ *          - scan l
+ *          - semiJoin
+ *              r.rightKeyHash in l.leftKeyHash
+ *              - project
+ *                  r.rightKeyHash = xx_hash64(r.rightKey)
+ *                  - scan r
+ *              - distinct aggregation
+ *                  group by leftKeyHash
+ *                  - project
+ *                      l.leftKeyHash = xx_hash64(l.leftKey)
+ *                      - scan l
+ * 
+ * And for join with multiple keys + *
+ *     - InnerJoin
+ *          leftKey1 = rightKey1 and leftKey2 = rightKey2
+ *          - scan l
+ *          - scan r
+ * 
+ * into + *
+ *     - InnerJoin
+ *          leftKey1 = rightKey1 and leftKey2 = rightKey2
+ *          - scan l
+ *          - semiJoin
+ *              r.rightKeysHash in l.leftKeysHash
+ *              - project
+ *                  r.rightKeysHash = combine_hash(xx_hash64(rightKey1), xx_hash64(rightKey2))
+ *                  - scan r
+ *              - distinct aggregation
+ *                  group by leftKeysHash
+ *                  - project
+ *                      l.leftKeysHash = combine_hash(xx_hash64(leftKey1), xx_hash64(leftKey2))
+ *                      - scan l
+ * 
+ */ + public class JoinPrefilter implements PlanOptimizer { @@ -73,7 +158,7 @@ public boolean isEnabled(Session session) public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { if (isEnabled(session)) { - Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator); + Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator, metadata.getFunctionAndTypeManager()); PlanNode rewritten = SimplePlanRewriter.rewriteWith(rewriter, plan, null); return PlanOptimizerResult.optimizerResult(rewritten, rewriter.isPlanChanged()); } @@ -88,14 +173,16 @@ private static class Rewriter private final Metadata metadata; private final PlanNodeIdAllocator idAllocator; private final VariableAllocator variableAllocator; + private final FunctionAndTypeManager functionAndTypeManager; private boolean planChanged; - private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator) + private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "functionAndTypeManager is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.variableAllocator = requireNonNull(variableAllocator, "idAllocator is null"); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); } @Override @@ -108,21 +195,37 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) PlanNode rewrittenRight = rewriteWith(this, right); List equiJoinClause = node.getCriteria(); - // We apply this for only left and inner join and the left side of the join is a simple scan and the join is on one key - if (equiJoinClause.size() == 1 && - (node.getType() == LEFT || node.getType() == INNER) && - isScanFilterProject(rewrittenLeft)) { - VariableReferenceExpression leftKey = equiJoinClause.stream().map(x -> x.getLeft()).findFirst().get(); - VariableReferenceExpression rightKey = equiJoinClause.stream().map(x -> x.getRight()).findFirst().get(); + // We apply this for only left and inner join and the left side of the join is a simple scan + if ((node.getType() == LEFT || node.getType() == INNER) && isScanFilterProject(rewrittenLeft) && !node.getCriteria().isEmpty()) { + List leftKeyList = equiJoinClause.stream().map(EquiJoinClause::getLeft).collect(toImmutableList()); + List rightKeyList = equiJoinClause.stream().map(EquiJoinClause::getRight).collect(toImmutableList()); + checkState(IntStream.range(0, leftKeyList.size()).boxed().allMatch(i -> leftKeyList.get(i).getType().equals(rightKeyList.get(i).getType()))); + + boolean hashJoinKey = leftKeyList.size() > 1 || (leftKeyList.get(0).getType().equals(VARCHAR) || leftKeyList.get(0).getType() instanceof VarcharType); // First create a SELECT DISTINCT leftKey FROM left Map leftVarMap = new HashMap(); - PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, ImmutableList.of(leftKey), leftVarMap); - PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, ImmutableList.of(leftVarMap.get(leftKey)), ImmutableList.of()); + PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, leftKeyList, leftVarMap); + ImmutableList.Builder expressionsToProject = ImmutableList.builder(); + if (hashJoinKey) { + RowExpression hashExpression = getVariableHash(leftKeyList); + expressionsToProject.add(hashExpression); + } + else { + expressionsToProject.add(leftVarMap.get(leftKeyList.get(0))); + } + PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, expressionsToProject.build(), ImmutableList.of()); + + VariableReferenceExpression rightKeyToFilter = rightKeyList.get(0); + if (hashJoinKey) { + RowExpression hashExpression = getVariableHash(rightKeyList); + rightKeyToFilter = variableAllocator.newVariable(hashExpression); + rewrittenRight = addProjections(rewrittenRight, idAllocator, ImmutableMap.of(rightKeyToFilter, hashExpression)); + } - // DISTINCT on the leftkey + // DISTINCT on the leftkey or hash if wide column PlanNode filteringSource = new AggregationNode( - leftKey.getSourceLocation(), + node.getLeft().getSourceLocation(), idAllocator.getNextId(), projectNode, ImmutableMap.of(), @@ -139,12 +242,12 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // Now we add a semijoin as the right side VariableReferenceExpression semiJoinOutput = variableAllocator.newVariable("semiJoinOutput", BOOLEAN); SemiJoinNode semiJoinNode = new SemiJoinNode( - rightKey.getSourceLocation(), + node.getRight().getSourceLocation(), idAllocator.getNextId(), node.getStatsEquivalentPlanNode(), rewrittenRight, filteringSource, - rightKey, + rightKeyToFilter, filteringSource.getOutputVariables().get(0), semiJoinOutput, Optional.empty(), @@ -153,6 +256,9 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) ImmutableMap.of()); rewrittenRight = new FilterNode(semiJoinNode.getSourceLocation(), idAllocator.getNextId(), semiJoinNode, semiJoinOutput); + if (rewrittenRight.getOutputVariables().size() > node.getRight().getOutputVariables().size()) { + rewrittenRight = restrictOutput(rewrittenRight, idAllocator, node.getRight().getOutputVariables()); + } } if (rewrittenLeft != node.getLeft() || rewrittenRight != node.getRight()) { @@ -167,5 +273,19 @@ public boolean isPlanChanged() { return planChanged; } + + private RowExpression getVariableHash(List inputVariables) + { + List hashExpressionList = inputVariables.stream().map(keyVariable -> + callOperator(functionAndTypeManager.getFunctionAndTypeResolver(), OperatorType.XX_HASH_64, BIGINT, keyVariable)).collect(toImmutableList()); + RowExpression hashExpression = hashExpressionList.get(0); + if (hashExpressionList.size() > 1) { + hashExpression = orNullHashCode(hashExpression); + for (int i = 1; i < hashExpressionList.size(); ++i) { + hashExpression = call(functionAndTypeManager, "combine_hash", BIGINT, hashExpression, orNullHashCode(hashExpressionList.get(i))); + } + } + return hashExpression; + } } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index e043327d2f6f5..588ee48790ad8 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -7768,21 +7768,62 @@ public void testLambdaInAggregation() @Test public void testJoinPrefilter() { - // Orig - String testQuery = "SELECT 1 from region join nation using(regionkey)"; - MaterializedResult result = computeActual("explain(type distributed) " + testQuery); - assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); - result = computeActual(testQuery); - assertEquals(result.getRowCount(), 25); - - // With feature - Session session = Session.builder(getSession()) - .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) - .build(); - result = computeActual(session, "explain(type distributed) " + testQuery); - assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); - result = computeActual(session, testQuery); - assertEquals(result.getRowCount(), 25); + { + // Orig + String testQuery = "SELECT 1 from region join nation using(regionkey)"; + MaterializedResult result = computeActual("explain(type distributed) " + testQuery); + assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); + result = computeActual(testQuery); + assertEquals(result.getRowCount(), 25); + + // With feature + Session session = Session.builder(getSession()) + .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) + .build(); + result = computeActual(session, "explain(type distributed) " + testQuery); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); + result = computeActual(session, testQuery); + assertEquals(result.getRowCount(), 25); + } + + { + // Orig + @Language("SQL") String testQuery = "SELECT 1 from region r join nation n on cast(r.regionkey as varchar) = cast(n.regionkey as varchar)"; + MaterializedResult result = computeActual("explain(type distributed) " + testQuery); + assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); + result = computeActual(testQuery); + assertEquals(result.getRowCount(), 25); + + // With feature + Session session = Session.builder(getSession()) + .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) + .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, String.valueOf(false)) + .build(); + result = computeActual(session, "explain(type distributed) " + testQuery); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1); + result = computeActual(session, testQuery); + assertEquals(result.getRowCount(), 25); + } + + { + // Orig + String testQuery = "SELECT 1 from lineitem l join orders o on l.orderkey = o.orderkey and l.suppkey = o.custkey"; + MaterializedResult result = computeActual("explain(type distributed) " + testQuery); + assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); + result = computeActual(testQuery); + assertEquals(result.getRowCount(), 37); + + // With feature + Session session = Session.builder(getSession()) + .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) + .build(); + result = computeActual(session, "explain(type distributed) " + testQuery); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1); + result = computeActual(session, testQuery); + assertEquals(result.getRowCount(), 37); + } } @Test