diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt index 44f85122d99f7..8418d929db9d0 100644 --- a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt @@ -2,6 +2,22 @@ remote exchange (GATHER, SINGLE, []) local exchange (GATHER, UNKNOWN, []) remote exchange (REPARTITION, ROUND_ROBIN, []) join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, [subtract_400]) + join (INNER, PARTITIONED): + final aggregation over (d_week_seq_232) + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [d_week_seq_232]) + partial aggregation over (d_week_seq_232) + join (INNER, REPLICATED): + remote exchange (REPARTITION, ROUND_ROBIN, []) + scan web_sales + scan catalog_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [d_week_seq_316]) + scan date_dim join (INNER, PARTITIONED): final aggregation over (d_week_seq) local exchange (GATHER, SINGLE, []) @@ -17,20 +33,3 @@ remote exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, [d_week_seq_83]) scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [subtract]) - join (INNER, PARTITIONED): - final aggregation over (d_week_seq_232) - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [d_week_seq_232]) - partial aggregation over (d_week_seq_232) - join (INNER, REPLICATED): - remote exchange (REPARTITION, ROUND_ROBIN, []) - scan web_sales - scan catalog_sales - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [d_week_seq_316]) - scan date_dim 
diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt index 23c58414ed691..e9fc6a9c0d2ec 100644 --- a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt @@ -1,7 +1,7 @@ local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, [d_week_seq, s_store_id]) + remote exchange (REPARTITION, HASH, [d_week_seq, d_week_seq_267, s_store_id]) join (INNER, REPLICATED): join (INNER, REPLICATED): final aggregation over (d_week_seq, ss_store_sk) @@ -20,7 +20,7 @@ local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan store local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [s_store_id_235, subtract]) + remote exchange (REPARTITION, HASH, [d_week_seq_147, d_week_seq_63, s_store_id_235]) join (INNER, REPLICATED): join (INNER, REPLICATED): final aggregation over (d_week_seq_147, ss_store_sk_127) diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java index c00c366f3b237..03cf107b15645 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.SpecialFormExpression.Form; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableSet; import java.util.ArrayDeque; import java.util.ArrayList; @@ -522,6 +523,17 @@ public ConvertNormalFormVisitorContext childContext() } } + private static class 
VariableReferenceBuilderVisitor + extends DefaultRowExpressionTraversalVisitor> + { + @Override + public Void visitVariableReference(VariableReferenceExpression variable, ImmutableSet.Builder builder) + { + builder.add(variable); + return null; + } + } + private class ConvertNormalFormVisitor implements RowExpressionVisitor { diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 07f43eab6ef08..9a07f16fb98b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -287,6 +287,7 @@ public final class SystemSessionProperties public static final String PULL_EXPRESSION_FROM_LAMBDA_ENABLED = "pull_expression_from_lambda_enabled"; public static final String REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION = "rewrite_constant_array_contains_to_in_expression"; public static final String INFER_INEQUALITY_PREDICATES = "infer_inequality_predicates"; + public static final String HANDLE_COMPLEX_EQUI_JOINS = "handle_complex_equi_joins"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. 
public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "simplified_expression_evaluation_enabled"; @@ -1673,6 +1674,11 @@ public SystemSessionProperties( INFER_INEQUALITY_PREDICATES, "Infer nonequality predicates for joins", featuresConfig.getInferInequalityPredicates(), + false), + booleanProperty( + HANDLE_COMPLEX_EQUI_JOINS, + "Handle complex equi-join conditions to open up join space for join reordering", + featuresConfig.getHandleComplexEquiJoins(), false)); } @@ -2821,4 +2827,9 @@ public static boolean shouldInferInequalityPredicates(Session session) { return session.getSystemProperty(INFER_INEQUALITY_PREDICATES, Boolean.class); } + + public static boolean shouldHandleComplexEquiJoins(Session session) + { + return session.getSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index d2df4224b4969..e6b1575e22f2b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -278,6 +278,7 @@ public class FeaturesConfig private boolean rewriteConstantArrayContainsToIn; private boolean preProcessMetadataCalls; + private boolean handleComplexEquiJoins = true; public enum PartitioningPrecisionStrategy { @@ -2751,4 +2752,17 @@ public FeaturesConfig setRewriteConstantArrayContainsToInEnabled(boolean rewrite this.rewriteConstantArrayContainsToIn = rewriteConstantArrayContainsToIn; return this; } + + public boolean getHandleComplexEquiJoins() + { + return handleComplexEquiJoins; + } + + @Config("optimizer.handle-complex-equi-joins") + @ConfigDescription("Handle complex equi-join conditions to open up join space for join reordering") + public FeaturesConfig setHandleComplexEquiJoins(boolean handleComplexEquiJoins) + { + this.handleComplexEquiJoins = 
handleComplexEquiJoins; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java index fa86d5abf501c..aea88ca5b81a0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -24,10 +24,12 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; @@ -35,7 +37,6 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.CanonicalJoinNode; import com.facebook.presto.sql.planner.EqualityInference; -import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -53,6 +54,7 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -66,12 +68,15 @@ import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; import static com.facebook.presto.SystemSessionProperties.getMaxReorderedJoins; -import 
static com.facebook.presto.common.function.OperatorType.EQUAL; +import static com.facebook.presto.SystemSessionProperties.shouldHandleComplexEquiJoins; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.expressions.LogicalRowExpressions.and; import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts; +import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.AUTOMATIC; import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; +import static com.facebook.presto.sql.planner.PlannerUtils.addProjections; +import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique; import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isBelowMaxBroadcastSize; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT; @@ -135,7 +140,8 @@ public boolean isEnabled(Session session) @Override public Result apply(JoinNode joinNode, Captures captures, Context context) { - MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, context.getLookup(), getMaxReorderedJoins(context.getSession()), functionResolution, determinismEvaluator); + MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, context.getLookup(), getMaxReorderedJoins(context.getSession()), shouldHandleComplexEquiJoins(context.getSession()), + functionResolution, determinismEvaluator); JoinEnumerator joinEnumerator = new JoinEnumerator( costComparator, multiJoinNode.getFilter(), @@ -143,11 +149,19 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) determinismEvaluator, functionResolution, metadata); + JoinEnumerationResult result = 
joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputVariables()); + if (!result.getPlanNode().isPresent()) { return Result.empty(); } - return Result.ofPlanNode(result.getPlanNode().get()); + + PlanNode transformedPlan = result.getPlanNode().get(); + if (multiJoinNode.getAssignments().isPresent()) { + transformedPlan = addProjections(transformedPlan, context.getIdAllocator(), multiJoinNode.getAssignments().get().getMap()); + } + + return Result.ofPlanNode(transformedPlan); } @VisibleForTesting @@ -166,6 +180,7 @@ static class JoinEnumerator private final Context context; private final Map, JoinEnumerationResult> memo = new HashMap<>(); + private final FunctionResolution functionResolution; @VisibleForTesting JoinEnumerator(CostComparator costComparator, RowExpression filter, Context context, DeterminismEvaluator determinismEvaluator, FunctionResolution functionResolution, Metadata metadata) @@ -181,6 +196,7 @@ static class JoinEnumerator this.metadata = requireNonNull(metadata, "metadata is null"); this.allFilterInference = createEqualityInference(metadata, filter); this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, functionResolution, metadata.getFunctionAndTypeManager()); + this.functionResolution = functionResolution; } private JoinEnumerationResult chooseJoinOrder(LinkedHashSet sources, List outputVariables) @@ -255,28 +271,31 @@ JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet private JoinEnumerationResult createJoin(LinkedHashSet leftSources, LinkedHashSet rightSources, List outputVariables) { - Set leftVariables = leftSources.stream() + HashSet leftVariables = leftSources.stream() .flatMap(node -> node.getOutputVariables().stream()) - .collect(toImmutableSet()); - Set rightVariables = rightSources.stream() + .collect(toCollection(HashSet::new)); + HashSet rightVariables = rightSources.stream() .flatMap(node -> node.getOutputVariables().stream()) - .collect(toImmutableSet()); + 
.collect(toCollection(HashSet::new)); List joinPredicates = getJoinPredicates(leftVariables, rightVariables); - List joinConditions = joinPredicates.stream() - .filter(JoinEnumerator::isJoinEqualityCondition) - .map(predicate -> toEquiJoinClause((CallExpression) predicate, leftVariables, context.getVariableAllocator())) - .collect(toImmutableList()); - if (joinConditions.isEmpty()) { + + VariableAllocator variableAllocator = context.getVariableAllocator(); + JoinCondition joinConditions = extractJoinConditions(joinPredicates, leftVariables, rightVariables, variableAllocator); + List joinClauses = joinConditions.getJoinClauses(); + List joinFilters = joinConditions.getJoinFilters(); + + //Update the left & right variable sets with any new variables generated + leftVariables.addAll(joinConditions.getNewLeftAssignments().keySet()); + rightVariables.addAll(joinConditions.getNewRightAssignments().keySet()); + + if (joinClauses.isEmpty()) { return INFINITE_COST_RESULT; } - List joinFilters = joinPredicates.stream() - .filter(predicate -> !isJoinEqualityCondition(predicate)) - .collect(toImmutableList()); Set requiredJoinVariables = ImmutableSet.builder() .addAll(outputVariables) - .addAll(VariablesExtractor.extractUnique(joinPredicates)) + .addAll(extractUnique(joinPredicates)) .build(); JoinEnumerationResult leftResult = getJoinSource( @@ -292,6 +311,13 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li } PlanNode left = leftResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")); + if (!joinConditions.getNewLeftAssignments().isEmpty()) { + ImmutableMap.Builder assignments = ImmutableMap.builder(); + left.getOutputVariables().forEach(outputVariable -> assignments.put(outputVariable, outputVariable)); + assignments.putAll(joinConditions.getNewLeftAssignments()); + + left = addProjections(left, idAllocator, assignments.build()); + } JoinEnumerationResult rightResult = getJoinSource( rightSources, @@ -306,6 +332,13 @@ 
private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li } PlanNode right = rightResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")); + if (!joinConditions.getNewRightAssignments().isEmpty()) { + ImmutableMap.Builder assignments = ImmutableMap.builder(); + right.getOutputVariables().forEach(outputVariable -> assignments.put(outputVariable, outputVariable)); + assignments.putAll(joinConditions.getNewRightAssignments()); + + right = addProjections(right, idAllocator, assignments.build()); + } // sort output variables so that the left input variables are first List sortedOutputVariables = Stream.concat(left.getOutputVariables().stream(), right.getOutputVariables().stream()) @@ -318,7 +351,7 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li INNER, left, right, - joinConditions, + joinClauses, sortedOutputVariables, joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)), Optional.empty(), @@ -374,22 +407,103 @@ private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List< return chooseJoinOrder(nodes, outputVariables); } - private static boolean isJoinEqualityCondition(RowExpression expression) + @VisibleForTesting + JoinCondition extractJoinConditions(List joinPredicates, + Set leftVariables, + Set rightVariables, + VariableAllocator variableAllocator) { - return expression instanceof CallExpression - && ((CallExpression) expression).getDisplayName().equals(EQUAL.getFunctionName().getObjectName()) - && ((CallExpression) expression).getArguments().size() == 2 - && ((CallExpression) expression).getArguments().get(0) instanceof VariableReferenceExpression - && ((CallExpression) expression).getArguments().get(1) instanceof VariableReferenceExpression; + ImmutableMap.Builder newLeftAssignments = ImmutableMap.builder(); + ImmutableMap.Builder newRightAssignments = ImmutableMap.builder(); + + ImmutableList.Builder joinClauses = ImmutableList.builder(); + ImmutableList.Builder 
joinFilters = ImmutableList.builder(); + + for (RowExpression predicate : joinPredicates) { + if (predicate instanceof CallExpression + && functionResolution.isEqualFunction(((CallExpression) predicate).getFunctionHandle()) + && ((CallExpression) predicate).getArguments().size() == 2) { + RowExpression argument0 = ((CallExpression) predicate).getArguments().get(0); + RowExpression argument1 = ((CallExpression) predicate).getArguments().get(1); + + // First check if arguments refer to different sides of join + Set argument0Vars = extractUnique(argument0); + Set argument1Vars = extractUnique(argument1); + if (!((leftVariables.containsAll(argument0Vars) && rightVariables.containsAll(argument1Vars)) + || (rightVariables.containsAll(argument0Vars) && leftVariables.containsAll(argument1Vars)))) { + // Arguments have a mix of join sides, use this predicate as a filter + joinFilters.add(predicate); + continue; + } + + // Next, check to see if first argument refers to left side and second argument to the right side + // If not, swap the arguments + if (leftVariables.containsAll(argument1Vars)) { + RowExpression temp = argument1; + argument1 = argument0; + argument0 = temp; + } + + // Next, check if we need to create new assignments for complex equi-join clauses + // E.g. 
leftVar = ADD(rightVar1, rightVar2) + if (!(argument0 instanceof VariableReferenceExpression)) { + VariableReferenceExpression newLeft = variableAllocator.newVariable(argument0); + newLeftAssignments.put(newLeft, argument0); + argument0 = newLeft; + } + + if (!(argument1 instanceof VariableReferenceExpression)) { + VariableReferenceExpression newRight = variableAllocator.newVariable(argument1); + newRightAssignments.put(newRight, argument1); + argument1 = newRight; + } + + joinClauses.add(new EquiJoinClause((VariableReferenceExpression) argument0, (VariableReferenceExpression) argument1)); + } + else { + joinFilters.add(predicate); + } + } + + return new JoinCondition(joinClauses.build(), joinFilters.build(), newLeftAssignments.build(), newRightAssignments.build()); } - private static EquiJoinClause toEquiJoinClause(CallExpression equality, Set leftVariables, VariableAllocator variableAllocator) + @VisibleForTesting + static class JoinCondition { - checkArgument(equality.getArguments().size() == 2, "Unexpected number of arguments in binary operator equals"); - VariableReferenceExpression leftVariable = (VariableReferenceExpression) equality.getArguments().get(0); - VariableReferenceExpression rightVariable = (VariableReferenceExpression) equality.getArguments().get(1); - EquiJoinClause equiJoinClause = new EquiJoinClause(leftVariable, rightVariable); - return leftVariables.contains(leftVariable) ? 
equiJoinClause : equiJoinClause.flip(); + List joinClauses; + List joinFilters; + Map newLeftAssignments; + Map newRightAssignments; + + public JoinCondition(List joinClauses, List joinFilters, + Map left, Map right) + { + this.joinClauses = joinClauses; + this.joinFilters = joinFilters; + this.newLeftAssignments = left; + this.newRightAssignments = right; + } + + public List getJoinClauses() + { + return joinClauses; + } + + public List getJoinFilters() + { + return joinFilters; + } + + public Map getNewLeftAssignments() + { + return newLeftAssignments; + } + + public Map getNewRightAssignments() + { + return newRightAssignments; + } } private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) @@ -459,8 +573,10 @@ static class MultiJoinNode { // Use a linked hash set to ensure optimizer is deterministic private final CanonicalJoinNode node; + private final Optional assignments; - public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List outputVariables) + public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List outputVariables, + Assignments assignments) { checkArgument(sources.size() > 1, "sources size is <= 1"); @@ -476,8 +592,23 @@ public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List ImmutableSet.of(filter), outputVariables); - List inputVariables = sources.stream().flatMap(source -> source.getOutputVariables().stream()).collect(toImmutableList()); - checkArgument(inputVariables.containsAll(outputVariables), "inputs do not contain all output variables"); + ImmutableSet inputVariables = sources.stream().flatMap(source -> source.getOutputVariables().stream()).collect(toImmutableSet()); + // We could have some output variables that were possibly generated from intermediate projects that were removed + // We will need to create a wrapper Project to add them back + Assignments.Builder builder = Assignments.builder(); + boolean nonIdentityAssignmentsFound = false; + for (VariableReferenceExpression 
outputVariable : outputVariables) { + if (inputVariables.contains(outputVariable)) { + builder.put(outputVariable, outputVariable); + continue; + } + checkState(assignments.getMap().containsKey(outputVariable), + "Output variable [%s] not found in input variables or intermediate assignments", outputVariable); + nonIdentityAssignmentsFound = true; + builder.put(outputVariable, assignments.get(outputVariable)); + } + + this.assignments = nonIdentityAssignmentsFound ? Optional.of(builder.build()) : Optional.empty(); } public RowExpression getFilter() @@ -495,6 +626,11 @@ public List getOutputVariables() return node.getOutputVariables(); } + public Optional getAssignments() + { + return assignments; + } + public static Builder builder() { return new Builder(); @@ -519,22 +655,25 @@ public boolean equals(Object obj) && getOutputVariables().equals(other.getOutputVariables()); } - static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) + static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) { // the number of sources is the number of joins + 1 - return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1, functionResolution, determinismEvaluator).toMultiJoinNode(); + return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1, handleComplexEquiJoins, functionResolution, determinismEvaluator).toMultiJoinNode(); } private static class JoinNodeFlattener { private final LinkedHashSet sources = new LinkedHashSet<>(); - private final List filters = new ArrayList<>(); + private final Assignments assignments; + private final boolean handleComplexEquiJoins; + private List filters = new ArrayList<>(); private final List outputVariables; private final FunctionResolution functionResolution; private final DeterminismEvaluator 
determinismEvaluator; private final Lookup lookup; - JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) + JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, + DeterminismEvaluator determinismEvaluator) { requireNonNull(node, "node is null"); checkState(node.getType() == INNER, "join type must be INNER"); @@ -542,13 +681,39 @@ private static class JoinNodeFlattener this.lookup = requireNonNull(lookup, "lookup is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null"); - flattenNode(node, sourceLimit); + this.handleComplexEquiJoins = handleComplexEquiJoins; + Assignments.Builder intermediateAssignments = Assignments.builder(); + flattenNode(node, sourceLimit, intermediateAssignments); + this.assignments = intermediateAssignments.build(); + rewriteFilterWithInlinedAssignments(intermediateAssignments.build()); + } + + private void rewriteFilterWithInlinedAssignments(Assignments assignments) + { + ImmutableList.Builder modifiedFilters = ImmutableList.builder(); + filters.forEach(filter -> modifiedFilters.add(replaceExpression(filter, assignments.getMap()))); + filters = modifiedFilters.build(); } - private void flattenNode(PlanNode node, int limit) + private void flattenNode(PlanNode node, int limit, Assignments.Builder assignmentsBuilder) { PlanNode resolved = lookup.resolve(node); + if (resolved instanceof ProjectNode) { + ProjectNode projectNode = (ProjectNode) resolved; + // A ProjectNode could be 'hiding' a join source by building an assignment of a complex equi-join criteria like `left.key = right1.key1 + right1.key2` + // We open up the join space by tracking the assignments from this Project node; these will be inlined into the overall filters 
once we finish + // traversing the join graph + if (handleComplexEquiJoins && lookup.resolve(projectNode.getSource()) instanceof JoinNode) { + assignmentsBuilder.putAll(projectNode.getAssignments()); + flattenNode(projectNode.getSource(), limit, assignmentsBuilder); + } + else { + sources.add(node); + } + return; + } + // (limit - 2) because you need to account for adding left and right side if (!(resolved instanceof JoinNode) || (sources.size() > (limit - 2))) { sources.add(node); @@ -562,8 +727,8 @@ private void flattenNode(PlanNode node, int limit) } // we set the left limit to limit - 1 to account for the node on the right - flattenNode(joinNode.getLeft(), limit - 1); - flattenNode(joinNode.getRight(), limit); + flattenNode(joinNode.getLeft(), limit - 1, assignmentsBuilder); + flattenNode(joinNode.getRight(), limit, assignmentsBuilder); joinNode.getCriteria().stream() .map(criteria -> toRowExpression(criteria, functionResolution)) .forEach(filters::add); @@ -572,7 +737,7 @@ private void flattenNode(PlanNode node, int limit) MultiJoinNode toMultiJoinNode() { - return new MultiJoinNode(sources, and(filters), outputVariables); + return new MultiJoinNode(sources, and(filters), outputVariables, assignments); } } @@ -602,7 +767,7 @@ public Builder setOutputVariables(VariableReferenceExpression... 
outputVariables public MultiJoinNode build() { - return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputVariables); + return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputVariables, Assignments.builder().build()); } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 64d281714969b..eaa35dc4ef5f4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -243,7 +243,8 @@ public void testDefaults() .setAddPartialNodeForRowNumberWithLimitEnabled(true) .setInferInequalityPredicates(false) .setPullUpExpressionFromLambdaEnabled(false) - .setRewriteConstantArrayContainsToInEnabled(false)); + .setRewriteConstantArrayContainsToInEnabled(false) + .setHandleComplexEquiJoins(true)); } @Test @@ -435,6 +436,7 @@ public void testExplicitPropertyMappings() .put("optimizer.infer-inequality-predicates", "true") .put("optimizer.pull-up-expression-from-lambda", "true") .put("optimizer.rewrite-constant-array-contains-to-in", "true") + .put("optimizer.handle-complex-equi-joins", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -623,7 +625,8 @@ public void testExplicitPropertyMappings() .setAddPartialNodeForRowNumberWithLimitEnabled(false) .setInferInequalityPredicates(true) .setPullUpExpressionFromLambdaEnabled(true) - .setRewriteConstantArrayContainsToInEnabled(true); + .setRewriteConstantArrayContainsToInEnabled(true) + .setHandleComplexEquiJoins(false); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java index 63dea91103d9f..4e386ab85427f 100644 --- 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java @@ -439,6 +439,7 @@ public void testNonPushedDownJoinFilterRemoval() "SELECT 1 FROM part t0, part t1, part t2 " + "WHERE t0.partkey = t1.partkey AND t0.partkey = t2.partkey " + "AND t0.size + t1.size = t2.size", + noJoinReordering(), anyTree( join( INNER, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index 9d0f7175ad91d..062cca4bbc5c9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -100,7 +100,7 @@ /** * RowExpression visitor which verifies if given expression (actual) is matching other RowExpression given as context (expected). 
*/ -final class RowExpressionVerifier +public final class RowExpressionVerifier extends AstVisitor { // either use variable or input reference for symbol mapping @@ -110,7 +110,7 @@ final class RowExpressionVerifier private final FunctionResolution functionResolution; private final Set lambdaArguments; - RowExpressionVerifier(SymbolAliases symbolAliases, Metadata metadata, Session session) + public RowExpressionVerifier(SymbolAliases symbolAliases, Metadata metadata, Session session) { this.symbolAliases = requireNonNull(symbolAliases, "symbolLayout is null"); this.metadata = requireNonNull(metadata, "metadata is null"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java index d60ea7b95e8a7..37e2f15062bed 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -15,46 +15,69 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.CachingCostProvider; import com.facebook.presto.cost.CachingStatsProvider; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.CostProvider; import com.facebook.presto.cost.PlanCostEstimate; import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.LogicalPropertiesProvider; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.DeterminismEvaluator; +import 
com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.assertions.RowExpressionVerifier; +import com.facebook.presto.sql.planner.assertions.SymbolAliases; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator; +import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.JoinCondition; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.MultiJoinNode; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; +import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.Arrays; import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; +import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression; 
import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.planner.optimizations.JoinNodeUtils.toRowExpression; +import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; public class TestJoinEnumerator { @@ -62,14 +85,20 @@ public class TestJoinEnumerator private Metadata metadata; private DeterminismEvaluator determinismEvaluator; private FunctionResolution functionResolution; + private PlanBuilder planBuilder; + private TestingRowExpressionTranslator rowExpressionTranslator; + private Session session; @BeforeClass public void setUp() { - queryRunner = new LocalQueryRunner(testSessionBuilder().build()); + session = testSessionBuilder().build(); + queryRunner = new LocalQueryRunner(session); metadata = queryRunner.getMetadata(); determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata); functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); + rowExpressionTranslator = new TestingRowExpressionTranslator(metadata); } @AfterClass(alwaysRun = true) @@ -109,7 +138,8 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() MultiJoinNode multiJoinNode = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(p.values(a1), p.values(b1))), TRUE_CONSTANT, - ImmutableList.of(a1, b1)); + ImmutableList.of(a1, b1), + Assignments.builder().build()); JoinEnumerator 
joinEnumerator = new JoinEnumerator( new CostComparator(1, 1, 1), multiJoinNode.getFilter(), @@ -122,6 +152,94 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() assertEquals(actual.getCost(), PlanCostEstimate.infinite()); } + @Test + public void testJoinClauseAndFilterInference() + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put("a", BIGINT); + builder.put("b", BIGINT); + builder.put("c", BIGINT); + builder.put("d", BIGINT); + Map variableMap = builder.build(); + VariableReferenceExpression a = variable("a", variableMap.get("a")); + VariableReferenceExpression b = variable("b", variableMap.get("b")); + VariableReferenceExpression c = variable("c", variableMap.get("c")); + VariableReferenceExpression d = variable("d", variableMap.get("d")); + + SymbolAliases.Builder newAliases = SymbolAliases.builder(); + newAliases.put("A", new SymbolReference("a")); + newAliases.put("B", new SymbolReference("b")); + newAliases.put("C", new SymbolReference("c")); + newAliases.put("D", new SymbolReference("d")); + SymbolAliases symbolAliases = newAliases.build(); + + // Simple join predicates on variable references + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B", null); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b", "c = d"), ImmutableSet.of(a, c), ImmutableSet.of(b, d), "A = B AND C = D", null); + // Complex join predicate - All variables must be from one join side to have the predicate be an equi-join clause + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C", null); + // Left and right side designation can be switched + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C", null); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c 
+ 1"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C + 1", null); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c + 1"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C + 1", null); + // If a predicate has a mix of variables from left & right sides, the predicate is treated as a filter + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = c"), ImmutableSet.of(a), ImmutableSet.of(b, c), null, "A + B = C"); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = 1"), ImmutableSet.of(a), ImmutableSet.of(b), null, "A + B = 1"); + // Test with multiple equi-join conditions and filters + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = ABS(b)", "a = ceil(b-c)", "b = c + 10"), + ImmutableSet.of(a), ImmutableSet.of(b, c), "A = abs(B) AND A = ceil(B-C)", "B = C + 10"); + } + + private List toRowExpressionList(Map variableTypeMap, String... predicates) + { + return Arrays.stream(predicates) + .map(p -> rowExpressionTranslator.translate(p, variableTypeMap)) + .collect(Collectors.toList()); + } + + private void assertJoinCondition(SymbolAliases symbolAliases, List joinPredicates, Set leftVariables, + Set rightVariables, String expectedEquiJoinClause, String expectedJoinFilter) + { + RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session); + JoinEnumerator joinEnumerator = new JoinEnumerator( + new CostComparator(1, 1, 1), + TRUE_CONSTANT, + createContext(), + determinismEvaluator, + functionResolution, + metadata); + + JoinCondition joinConditions = joinEnumerator.extractJoinConditions(joinPredicates, + leftVariables, rightVariables, new VariableAllocator()); + + Optional equiJoinExpressionInlined = joinConditions.getJoinClauses().stream() + .map(criteria -> toRowExpression(criteria, functionResolution)) + // We may have made left or right assignments to build the equi join clause + // We inline these assignments for 
building the equi join clause to verify + .map(expression -> replaceExpression(expression, joinConditions.getNewLeftAssignments())) + .map(expression -> replaceExpression(expression, joinConditions.getNewRightAssignments())) + .reduce(LogicalRowExpressions::and); + + if (equiJoinExpressionInlined.isPresent()) { + assertNotNull(expectedEquiJoinClause); + assertTrue(verifier.process(expression(expectedEquiJoinClause), equiJoinExpressionInlined.get())); + } + else { + assertNull(expectedEquiJoinClause); + } + + Optional joinFilter = joinConditions.getJoinFilters().stream() + .reduce(LogicalRowExpressions::and); + + if (joinFilter.isPresent()) { + assertNotNull(expectedJoinFilter); + assertTrue(verifier.process(expression(expectedJoinFilter), joinFilter.get())); + } + else { + assertNull(expectedJoinFilter); + } + } + private Rule.Context createContext() { PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java index 6ba4bc4f7c9b2..dd057fe418232 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java @@ -15,7 +15,9 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; @@ -86,7 +88,7 @@ public void testDoesNotAllowOuterJoin() ImmutableList.of(equiJoinClause(a1, b1)), ImmutableList.of(a1, b1), 
Optional.empty()); - toMultiJoinNode(outerJoin, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator); + toMultiJoinNode(outerJoin, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator); } @Test @@ -116,7 +118,7 @@ public void testDoesNotConvertNestedOuterJoins() .setSources(leftJoin, valuesC).setFilter(createEqualsExpression(a1, c1)) .setOutputVariables(a1, b1, c1) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator), expected); } @Test @@ -149,7 +151,7 @@ public void testRetainsOutputSymbols() .setFilter(and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1))) .setOutputVariables(a1, b1) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator), expected); } @Test @@ -206,8 +208,9 @@ public void testCombinesCriteriaAndFilters() MultiJoinNode expected = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(valuesA, valuesB, valuesC)), and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1), bcFilter, abcFilter), - ImmutableList.of(a1, b1, b2, c1, c2)); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); + ImmutableList.of(a1, b1, b2, c1, c2), + Assignments.builder().build()); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator), expected); } @Test @@ -258,7 +261,7 @@ public void testConvertsBushyTrees() .setFilter(and(createEqualsExpression(a1, b1), createEqualsExpression(a1, c1), createEqualsExpression(d1, e1), createEqualsExpression(d2, e2), 
createEqualsExpression(b1, e1))) .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, true, functionResolution, determinismEvaluator), expected); } @Test @@ -311,10 +314,56 @@ public void testMoreThanJoinLimit() .setFilter(and(createEqualsExpression(a1, c1), createEqualsExpression(b1, e1))) .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), 2, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), 2, true, functionResolution, determinismEvaluator), expected); + } + + @Test + public void testProjectNodesBetweenJoinNodesAreFlattenedForComplexEquiJoins() + { + PlanBuilder p = planBuilder(); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression c1 = p.variable("C1"); + VariableReferenceExpression sum = p.variable("SUM"); + + ValuesNode valuesA = p.values(a1); + ValuesNode valuesB = p.values(b1); + ValuesNode valuesC = p.values(c1); + Assignments sumAssignment = Assignments.builder().put(sum, createAddExpression(a1, b1)).build(); + + ProjectNode intermediateProject = p.project(sumAssignment, p.join( + INNER, + valuesA, + valuesB, + ImmutableList.of(equiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty())); + JoinNode joinNode = p.join( + INNER, + intermediateProject, + valuesC, + ImmutableList.of(equiJoinClause(sum, c1)), + ImmutableList.of(), + Optional.empty()); + + MultiJoinNode expected = MultiJoinNode.builder() + .setSources(valuesA, valuesB, valuesC) + .setFilter(and(createEqualsExpression(a1, b1), createEqualsExpression(createAddExpression(a1, b1), c1))) + .setOutputVariables() + .build(); + assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, /*handleComplexEquiJoins*/ true, 
functionResolution, determinismEvaluator), expected); + + // Negative test - when handleComplexEquiJoins = false, we have a split join space; the ProjectNode is not flattened + expected = MultiJoinNode.builder() + .setSources(intermediateProject, valuesC) + .setFilter(createEqualsExpression(sum, c1)) + .setOutputVariables() + .build(); + + assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, /*handleComplexEquiJoins*/ false, functionResolution, determinismEvaluator), expected); } - private RowExpression createEqualsExpression(VariableReferenceExpression left, VariableReferenceExpression right) + private RowExpression createEqualsExpression(RowExpression left, RowExpression right) { return call( OperatorType.EQUAL.name(), @@ -323,6 +372,15 @@ private RowExpression createEqualsExpression(VariableReferenceExpression left, V ImmutableList.of(left, right)); } + private RowExpression createAddExpression(RowExpression left, RowExpression right) + { + return call( + OperatorType.ADD.name(), + functionResolution.arithmeticFunction(OperatorType.ADD, left.getType(), right.getType()), + BIGINT, + ImmutableList.of(left, right)); + } + private EquiJoinClause equiJoinClause(VariableReferenceExpression variable1, VariableReferenceExpression variable2) { return new EquiJoinClause(variable1, variable2); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java index c29c366672392..155c95d33b8e8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; 
+import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; @@ -32,12 +33,14 @@ import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Optional; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; +import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -49,8 +52,13 @@ import static com.facebook.presto.metadata.FunctionAndTypeManager.qualifyObjectName; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.BROADCAST; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static 
com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; @@ -59,6 +67,7 @@ import static com.facebook.presto.sql.relational.Expressions.variable; public class TestReorderJoins + extends BasePlanTest { private RuleTester tester; private FunctionResolution functionResolution; @@ -67,6 +76,37 @@ public class TestReorderJoins private static final ImmutableList> TWO_ROWS = ImmutableList.of(ImmutableList.of(), ImmutableList.of()); private static final QualifiedName RANDOM = QualifiedName.of("random"); + @DataProvider + public static Object[][] tableSpecificationPermutations() + { + return new Object[][] { + {"supplier s, partsupp ps, customer c, orders o"}, + {"supplier s, partsupp ps, orders o, customer c"}, + {"supplier s, customer c, partsupp ps, orders o"}, + {"supplier s, customer c, orders o, partsupp ps"}, + {"supplier s, orders o, partsupp ps, customer c"}, + {"supplier s, orders o, customer c, partsupp ps"}, + {"partsupp ps, supplier s, customer c, orders o"}, + {"partsupp ps, supplier s, orders o, customer c"}, + {"partsupp ps, customer c, supplier s, orders o"}, + {"partsupp ps, customer c, orders o, supplier s"}, + {"partsupp ps, orders o, supplier s, customer c"}, + {"partsupp ps, orders o, customer c, supplier s"}, + {"customer c, supplier s, partsupp ps, orders o"}, + {"customer c, supplier s, orders o, partsupp ps"}, + {"customer c, partsupp ps, supplier s, orders o"}, + {"customer c, partsupp ps, orders o, supplier s"}, + {"customer c, orders o, supplier s, partsupp ps"}, + {"customer c, orders o, partsupp ps, supplier s"}, + {"orders o, supplier s, partsupp ps, customer c"}, + {"orders o, supplier s, customer c, partsupp ps"}, + {"orders o, partsupp ps, supplier s, customer c"}, + {"orders o, partsupp ps, customer c, supplier s"}, + {"orders o, customer c, supplier s, partsupp ps"}, + {"orders o, customer c, partsupp ps, supplier s"} + }; + } + 
@BeforeClass public void setUp() { @@ -74,7 +114,8 @@ public void setUp() ImmutableList.of(), ImmutableMap.of( JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name(), - JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name()), + JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name(), + HANDLE_COMPLEX_EQUI_JOINS, "true"), Optional.of(4)); this.functionResolution = new FunctionResolution(tester.getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver()); } @@ -550,6 +591,119 @@ public void testReorderAndReplicate() values(ImmutableMap.of("A1", 0)))); } + /** + * This test asserts that join re-ordering works as expected for complex equi join clauses ('s.acctbal = c.acctbal + o.totalprice') + * and works irrespective of the order in which tables are specified in the FROM clause + * + * @param tableSpecificationOrder The table specification order + */ + @Test(dataProvider = "tableSpecificationPermutations") + public void testComplexEquiJoinCriteria(String tableSpecificationOrder) + { + // For a fully connected join graph, we don't see any CrossJoins + String query = "select 1 from " + tableSpecificationOrder + " where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal = c.acctbal + o.totalprice"; + PlanMatchPattern expectedPlan = + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree(tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("SUM", "S_ACCTBAL")), + anyTree( + project(ImmutableMap.of("SUM", expression("C_ACCTBAL + O_TOTALPRICE")), + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), + anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal",
"S_SUPPKEY", "suppkey"))))))); + assertPlan(query, expectedPlan); + + // The plan is identical to the plan for the fully spelled out version of the Join + String fullQuery = "select 1 from (supplier s inner join partsupp ps on s.suppkey = ps.suppkey) inner join (orders o inner join customer c on c.custkey = o.custkey) " + + " on s.acctbal = c.acctbal + o.totalprice"; + assertPlan(fullQuery, expectedPlan); + } + + @Test + public void testComplexEquiJoinCriteriaForDisjointGraphs() + { + // If the join clause is written with the Left/Right side referring to both sides of a Join node, an equi-join condition cannot be inferred + // and the join space is broken up. Hence, we observe a CrossJoin node + assertPlan("select 1 from supplier s, partsupp ps, customer c, orders o where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal - c.acctbal = o.totalprice", + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("C_CUSTKEY", "O_CUSTKEY"), equiJoinClause("SUBTRACT", "O_TOTALPRICE")), + anyTree( + project(ImmutableMap.of("SUBTRACT", expression("S_ACCTBAL - C_ACCTBAL")), + join(INNER, + ImmutableList.of(), //CrossJoin + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree(tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey")))), + anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice")))))); + + // The table specification order determines the join order for such cases + // With the below table specification order, the planner adds the complex equi-join condition as a FilterNode on top of a JoinNode + assertPlan("select 1 from orders o, customer c, supplier s, partsupp ps where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal - c.acctbal = o.totalprice", + 
anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree( + tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), + anyTree( + filter("O_TOTALPRICE = S_ACCTBAL - C_ACCTBAL", + join(INNER, + ImmutableList.of(), //CrossJoin + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree(tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), + anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey"))))))))); + + // For sub-graphs that are fully connected, join-reordering works with complex predicates as expected + // The rest of the join graph is connected using a CrossJoin + assertPlan("select 1 " + + "from orders o, customer c, supplier s, partsupp ps, part p " + + "where s.suppkey = ps.suppkey " + + " and c.custkey = o.custkey " + + " and s.acctbal = c.acctbal + o.totalprice" + + " and ps.partkey - p.partkey = 0 ", + anyTree( + filter("PS_PARTKEY - P_PARTKEY = 0", + join(INNER, + ImmutableList.of(), // CrossJoin + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree( + tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey", "PS_PARTKEY", "partkey"))), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("SUM", "S_ACCTBAL")), + anyTree( + project(ImmutableMap.of("SUM", expression("C_ACCTBAL + O_TOTALPRICE")), + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), + anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey")))))), + anyTree( + tableScan("part", ImmutableMap.of("P_PARTKEY", "partkey"))))))); + 
} + private RuleAssert assertReorderJoins() { return tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), tester.getMetadata())); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java index 369b0277abe66..374310940b85b 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java @@ -24,6 +24,7 @@ import com.google.common.collect.Iterables; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOINS_NOT_NULL_INFERENCE_STRATEGY; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -470,6 +471,29 @@ public void testJoinWithComplexExpressions3() "SELECT SUM(custkey) FROM lineitem JOIN orders ON lineitem.orderkey + 1 = orders.orderkey + 1", // H2 takes a million years because it can't join efficiently on a non-indexed field/expression "SELECT SUM(custkey) FROM lineitem JOIN orders ON lineitem.orderkey = orders.orderkey "); + + Session handleComplexEquiJoins = Session.builder(getSession()) + .setSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, "true") + .build(); + + assertQueryWithSameQueryRunner( + handleComplexEquiJoins, + "select c.custkey, ps.partkey, s.suppkey, o.orderkey " + + "from customer c, " + + " partsupp ps, " + + " orders o, " + + " supplier s " + + "where s.suppkey = ps.suppkey " + + " and c.custkey = o.custkey " + + " and s.nationkey + ps.partkey = c.nationkey " + + "order by c.custkey, ps.partkey, s.suppkey, o.orderkey", + noJoinReordering(), + "select c.custkey, ps.partkey, s.suppkey, o.orderkey " + + "from (customer c inner join orders o ON c.custkey = o.custkey) " + 
+ " inner join " + + " (partsupp ps inner join supplier s ON s.suppkey = ps.suppkey) " + + " on s.nationkey + ps.partkey = c.nationkey " + + "order by c.custkey, ps.partkey, s.suppkey, o.orderkey"); } @Test diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 1ab20c4fbec08..f561fba080b1e 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -169,6 +169,12 @@ protected void assertQueryWithSameQueryRunner(Session session, @Language("SQL") QueryAssertions.assertQuery(queryRunner, session, actual, queryRunner, expected, false, false); } + protected void assertQueryWithSameQueryRunner(Session actualSession, @Language("SQL") String actual, Session expectedSession, @Language("SQL") String expected) + { + checkArgument(!actual.equals(expected)); + QueryAssertions.assertQuery(queryRunner, actualSession, actual, queryRunner, expectedSession, expected, false, false); + } + protected void assertQuery(Session session, @Language("SQL") String actual, @Language("SQL") String expected) { QueryAssertions.assertQuery(queryRunner, session, actual, expectedQueryRunner, expected, false, false);