From 06dbf2795bef9e29ea7af38362aad499cad36eeb Mon Sep 17 00:00:00 2001 From: Anant Aneja <1797669+aaneja@users.noreply.github.com> Date: Mon, 30 Oct 2023 09:35:11 +0530 Subject: [PATCH] Enhance join reordering to work with non-simple equi-join clauses Join predicates like `left.key = right1.key1 + right2.key2` can reduce the join space by appearing as Project nodes or Join nodes with no equi-join clauses in the join graph. This commit fixes this behavior by removing any intermediate Projects in the join graph and only creating them on-the-fly while choosing the join order --- .../sql/planner/TestTpcdsCostBasedPlan.java | 4 +- .../sql/planner/TestTpchCostBasedPlan.java | 4 +- .../resources/sql/presto/tpcds/q02.plan.txt | 33 +- .../resources/sql/presto/tpcds/q59.plan.txt | 4 +- .../presto/SystemSessionProperties.java | 11 + .../presto/sql/analyzer/FeaturesConfig.java | 14 + .../presto/sql/planner/CanonicalJoinNode.java | 12 + .../planner/iterative/rule/ReorderJoins.java | 329 +++++++++++++++--- .../sql/planner/plan/AssignmentUtils.java | 12 + .../sql/analyzer/TestFeaturesConfig.java | 7 +- .../presto/sql/planner/TestDynamicFilter.java | 1 + .../assertions/RowExpressionVerifier.java | 4 +- .../iterative/rule/TestJoinEnumerator.java | 122 ++++++- .../iterative/rule/TestJoinNodeFlattener.java | 154 +++++++- .../iterative/rule/TestReorderJoins.java | 197 ++++++++++- .../presto/tests/AbstractTestJoinQueries.java | 24 ++ .../tests/AbstractTestQueryFramework.java | 7 + 17 files changed, 857 insertions(+), 82 deletions(-) diff --git a/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpcdsCostBasedPlan.java b/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpcdsCostBasedPlan.java index 98719c32aa610..b66998e5509ed 100644 --- a/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpcdsCostBasedPlan.java +++ 
b/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpcdsCostBasedPlan.java @@ -24,6 +24,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_JOINS_WITH_EMPTY_SOURCES; @@ -54,7 +55,8 @@ public TestTpcdsCostBasedPlan() .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel .setSystemProperty(JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name()) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) - .setSystemProperty(OPTIMIZE_JOINS_WITH_EMPTY_SOURCES, "false"); + .setSystemProperty(OPTIMIZE_JOINS_WITH_EMPTY_SOURCES, "false") + .setSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, "true"); LocalQueryRunner queryRunner = LocalQueryRunner.queryRunnerWithFakeNodeCountForStats(sessionBuilder.build(), 8); queryRunner.createCatalog( diff --git a/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpchCostBasedPlan.java b/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpchCostBasedPlan.java index 463f514468351..ecb20d8918652 100644 --- a/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpchCostBasedPlan.java +++ b/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/TestTpchCostBasedPlan.java @@ -25,6 +25,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static 
com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -54,7 +55,8 @@ public TestTpchCostBasedPlan() .setSchema("sf3000.0") .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel .setSystemProperty(JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name()) - .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()); + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name()) + .setSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, "true"); LocalQueryRunner queryRunner = LocalQueryRunner.queryRunnerWithFakeNodeCountForStats(sessionBuilder.build(), 8); queryRunner.createCatalog( diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt index 44f85122d99f7..8418d929db9d0 100644 --- a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt @@ -2,6 +2,22 @@ remote exchange (GATHER, SINGLE, []) local exchange (GATHER, UNKNOWN, []) remote exchange (REPARTITION, ROUND_ROBIN, []) join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, [subtract_400]) + join (INNER, PARTITIONED): + final aggregation over (d_week_seq_232) + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [d_week_seq_232]) + partial aggregation over (d_week_seq_232) + join (INNER, REPLICATED): + remote exchange (REPARTITION, ROUND_ROBIN, []) + scan web_sales + scan catalog_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [d_week_seq_316]) + scan date_dim join (INNER, PARTITIONED): final aggregation over (d_week_seq) local exchange (GATHER, SINGLE, []) @@ -17,20 +33,3 @@ remote exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, 
[]) remote exchange (REPARTITION, HASH, [d_week_seq_83]) scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [subtract]) - join (INNER, PARTITIONED): - final aggregation over (d_week_seq_232) - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [d_week_seq_232]) - partial aggregation over (d_week_seq_232) - join (INNER, REPLICATED): - remote exchange (REPARTITION, ROUND_ROBIN, []) - scan web_sales - scan catalog_sales - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [d_week_seq_316]) - scan date_dim diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt index 23c58414ed691..e9fc6a9c0d2ec 100644 --- a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt @@ -1,7 +1,7 @@ local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, [d_week_seq, s_store_id]) + remote exchange (REPARTITION, HASH, [d_week_seq, d_week_seq_267, s_store_id]) join (INNER, REPLICATED): join (INNER, REPLICATED): final aggregation over (d_week_seq, ss_store_sk) @@ -20,7 +20,7 @@ local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan store local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [s_store_id_235, subtract]) + remote exchange (REPARTITION, HASH, [d_week_seq_147, d_week_seq_63, s_store_id_235]) join (INNER, REPLICATED): join (INNER, REPLICATED): final aggregation over (d_week_seq_147, ss_store_sk_127) diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 
bb05f086f9566..9237ff568815d 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -295,6 +295,7 @@ public final class SystemSessionProperties public static final String INFER_INEQUALITY_PREDICATES = "infer_inequality_predicates"; public static final String ENABLE_HISTORY_BASED_SCALED_WRITER = "enable_history_based_scaled_writer"; public static final String REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN = "remove_redundant_cast_to_varchar_in_join"; + public static final String HANDLE_COMPLEX_EQUI_JOINS = "handle_complex_equi_joins"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled"; @@ -1775,6 +1776,11 @@ public SystemSessionProperties( REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "If both left and right side of join clause are varchar cast from int/bigint, remove the cast here", featuresConfig.isRemoveRedundantCastToVarcharInJoin(), + false), + booleanProperty( + HANDLE_COMPLEX_EQUI_JOINS, + "Handle complex equi-join conditions to open up join space for join reordering", + featuresConfig.getHandleComplexEquiJoins(), false)); } @@ -2959,4 +2965,9 @@ public static boolean isRemoveRedundantCastToVarcharInJoinEnabled(Session sessio { return session.getSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, Boolean.class); } + + public static boolean shouldHandleComplexEquiJoins(Session session) + { + return session.getSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index e0dba0b0334a3..142486b2d73ae 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -281,6 +281,7 @@ public class FeaturesConfig private boolean rewriteConstantArrayContainsToIn; private boolean preProcessMetadataCalls; + private boolean handleComplexEquiJoins; private boolean useHBOForScaledWriters; private boolean removeRedundantCastToVarcharInJoin = true; @@ -2831,4 +2832,17 @@ public FeaturesConfig setRemoveRedundantCastToVarcharInJoin(boolean removeRedund this.removeRedundantCastToVarcharInJoin = removeRedundantCastToVarcharInJoin; return this; } + + public boolean getHandleComplexEquiJoins() + { + return handleComplexEquiJoins; + } + + @Config("optimizer.handle-complex-equi-joins") + @ConfigDescription("Handle complex equi-join conditions to open up join space for join reordering") + public FeaturesConfig setHandleComplexEquiJoins(boolean handleComplexEquiJoins) + { + this.handleComplexEquiJoins = handleComplexEquiJoins; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalJoinNode.java index ad86d227e9473..af73714f21ca2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalJoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalJoinNode.java @@ -125,4 +125,16 @@ public int hashCode() { return Objects.hash(sources, type, criteria, filters, outputVariables); } + + @Override + public String toString() + { + return "CanonicalJoinNode{" + + "sources=" + sources + + ", type=" + type + + ", criteria=" + criteria + + ", filters=" + filters + + ", outputVariables=" + outputVariables + + '}'; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java index 
7fe3c95762a34..d9dd3dd1acd83 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -24,10 +24,12 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; @@ -35,7 +37,6 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.CanonicalJoinNode; import com.facebook.presto.sql.planner.EqualityInference; -import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -50,9 +51,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; +import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -66,18 +69,23 @@ import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; import static com.facebook.presto.SystemSessionProperties.getMaxReorderedJoins; -import static com.facebook.presto.common.function.OperatorType.EQUAL; 
+import static com.facebook.presto.SystemSessionProperties.shouldHandleComplexEquiJoins; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.expressions.LogicalRowExpressions.and; import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts; +import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression; +import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.AUTOMATIC; import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; +import static com.facebook.presto.sql.planner.PlannerUtils.addProjections; +import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique; import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isBelowMaxBroadcastSize; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.MultiJoinNode.toMultiJoinNode; import static com.facebook.presto.sql.planner.optimizations.JoinNodeUtils.toRowExpression; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.getNonIdentityAssignments; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; @@ -148,7 +156,8 @@ public String getStatsSource() @Override public Result apply(JoinNode joinNode, Captures captures, Context context) { - MultiJoinNode multiJoinNode = 
toMultiJoinNode(joinNode, context.getLookup(), getMaxReorderedJoins(context.getSession()), functionResolution, determinismEvaluator); + MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, context.getLookup(), getMaxReorderedJoins(context.getSession()), shouldHandleComplexEquiJoins(context.getSession()), + functionResolution, determinismEvaluator); JoinEnumerator joinEnumerator = new JoinEnumerator( costComparator, multiJoinNode.getFilter(), @@ -156,12 +165,26 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) determinismEvaluator, functionResolution, metadata); + JoinEnumerationResult result = joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputVariables()); + if (!result.getPlanNode().isPresent()) { return Result.empty(); } + statsSource = context.getStatsProvider().getStats(joinNode).getSourceInfo().getSourceInfoName(); - return Result.ofPlanNode(result.getPlanNode().get()); + + PlanNode transformedPlan = result.getPlanNode().get(); + if (!multiJoinNode.getAssignments().isEmpty()) { + transformedPlan = new ProjectNode( + transformedPlan.getSourceLocation(), + context.getIdAllocator().getNextId(), + transformedPlan, + multiJoinNode.getAssignments(), + LOCAL); + } + + return Result.ofPlanNode(transformedPlan); } @VisibleForTesting @@ -180,6 +203,7 @@ static class JoinEnumerator private final Context context; private final Map, JoinEnumerationResult> memo = new HashMap<>(); + private final FunctionResolution functionResolution; @VisibleForTesting JoinEnumerator(CostComparator costComparator, RowExpression filter, Context context, DeterminismEvaluator determinismEvaluator, FunctionResolution functionResolution, Metadata metadata) @@ -195,6 +219,7 @@ static class JoinEnumerator this.metadata = requireNonNull(metadata, "metadata is null"); this.allFilterInference = createEqualityInference(metadata, filter); this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, functionResolution, 
metadata.getFunctionAndTypeManager()); + this.functionResolution = functionResolution; } private JoinEnumerationResult chooseJoinOrder(LinkedHashSet sources, List outputVariables) @@ -269,28 +294,31 @@ JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet private JoinEnumerationResult createJoin(LinkedHashSet leftSources, LinkedHashSet rightSources, List outputVariables) { - Set leftVariables = leftSources.stream() + HashSet leftVariables = leftSources.stream() .flatMap(node -> node.getOutputVariables().stream()) - .collect(toImmutableSet()); - Set rightVariables = rightSources.stream() + .collect(toCollection(HashSet::new)); + HashSet rightVariables = rightSources.stream() .flatMap(node -> node.getOutputVariables().stream()) - .collect(toImmutableSet()); + .collect(toCollection(HashSet::new)); List joinPredicates = getJoinPredicates(leftVariables, rightVariables); - List joinConditions = joinPredicates.stream() - .filter(JoinEnumerator::isJoinEqualityCondition) - .map(predicate -> toEquiJoinClause((CallExpression) predicate, leftVariables, context.getVariableAllocator())) - .collect(toImmutableList()); - if (joinConditions.isEmpty()) { + + VariableAllocator variableAllocator = context.getVariableAllocator(); + JoinCondition joinConditions = extractJoinConditions(joinPredicates, leftVariables, rightVariables, variableAllocator); + List joinClauses = joinConditions.getJoinClauses(); + List joinFilters = joinConditions.getJoinFilters(); + + //Update the left & right variable sets with any new variables generated + leftVariables.addAll(joinConditions.getNewLeftAssignments().keySet()); + rightVariables.addAll(joinConditions.getNewRightAssignments().keySet()); + + if (joinClauses.isEmpty()) { return INFINITE_COST_RESULT; } - List joinFilters = joinPredicates.stream() - .filter(predicate -> !isJoinEqualityCondition(predicate)) - .collect(toImmutableList()); Set requiredJoinVariables = ImmutableSet.builder() .addAll(outputVariables) - 
.addAll(VariablesExtractor.extractUnique(joinPredicates)) + .addAll(extractUnique(joinPredicates)) .build(); JoinEnumerationResult leftResult = getJoinSource( @@ -306,6 +334,13 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li } PlanNode left = leftResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")); + if (!joinConditions.getNewLeftAssignments().isEmpty()) { + ImmutableMap.Builder assignments = ImmutableMap.builder(); + left.getOutputVariables().forEach(outputVariable -> assignments.put(outputVariable, outputVariable)); + assignments.putAll(joinConditions.getNewLeftAssignments()); + + left = addProjections(left, idAllocator, assignments.build()); + } JoinEnumerationResult rightResult = getJoinSource( rightSources, @@ -320,6 +355,13 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li } PlanNode right = rightResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")); + if (!joinConditions.getNewRightAssignments().isEmpty()) { + ImmutableMap.Builder assignments = ImmutableMap.builder(); + right.getOutputVariables().forEach(outputVariable -> assignments.put(outputVariable, outputVariable)); + assignments.putAll(joinConditions.getNewRightAssignments()); + + right = addProjections(right, idAllocator, assignments.build()); + } // sort output variables so that the left input variables are first List sortedOutputVariables = Stream.concat(left.getOutputVariables().stream(), right.getOutputVariables().stream()) @@ -332,7 +374,7 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li INNER, left, right, - joinConditions, + joinClauses, sortedOutputVariables, joinFilters.isEmpty() ? 
Optional.empty() : Optional.of(and(joinFilters)), Optional.empty(), @@ -388,22 +430,103 @@ private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List< return chooseJoinOrder(nodes, outputVariables); } - private static boolean isJoinEqualityCondition(RowExpression expression) + @VisibleForTesting + JoinCondition extractJoinConditions(List joinPredicates, + Set leftVariables, + Set rightVariables, + VariableAllocator variableAllocator) { - return expression instanceof CallExpression - && ((CallExpression) expression).getDisplayName().equals(EQUAL.getFunctionName().getObjectName()) - && ((CallExpression) expression).getArguments().size() == 2 - && ((CallExpression) expression).getArguments().get(0) instanceof VariableReferenceExpression - && ((CallExpression) expression).getArguments().get(1) instanceof VariableReferenceExpression; + ImmutableMap.Builder newLeftAssignments = ImmutableMap.builder(); + ImmutableMap.Builder newRightAssignments = ImmutableMap.builder(); + + ImmutableList.Builder joinClauses = ImmutableList.builder(); + ImmutableList.Builder joinFilters = ImmutableList.builder(); + + for (RowExpression predicate : joinPredicates) { + if (predicate instanceof CallExpression + && functionResolution.isEqualFunction(((CallExpression) predicate).getFunctionHandle()) + && ((CallExpression) predicate).getArguments().size() == 2) { + RowExpression argument0 = ((CallExpression) predicate).getArguments().get(0); + RowExpression argument1 = ((CallExpression) predicate).getArguments().get(1); + + // First check if arguments refer to different sides of join + Set argument0Vars = extractUnique(argument0); + Set argument1Vars = extractUnique(argument1); + if (!((leftVariables.containsAll(argument0Vars) && rightVariables.containsAll(argument1Vars)) + || (rightVariables.containsAll(argument0Vars) && leftVariables.containsAll(argument1Vars)))) { + // Arguments have a mix of join sides, use this predicate as a filter + joinFilters.add(predicate); + continue; + } + 
+ // Next, check to see if first argument refers to left side and second argument to the right side + // If not, swap the arguments + if (leftVariables.containsAll(argument1Vars)) { + RowExpression temp = argument1; + argument1 = argument0; + argument0 = temp; + } + + // Next, check if we need to create new assignments for complex equi-join clauses + // E.g. leftVar = ADD(rightVar1, rightVar2) + if (!(argument0 instanceof VariableReferenceExpression)) { + VariableReferenceExpression newLeft = variableAllocator.newVariable(argument0); + newLeftAssignments.put(newLeft, argument0); + argument0 = newLeft; + } + + if (!(argument1 instanceof VariableReferenceExpression)) { + VariableReferenceExpression newRight = variableAllocator.newVariable(argument1); + newRightAssignments.put(newRight, argument1); + argument1 = newRight; + } + + joinClauses.add(new EquiJoinClause((VariableReferenceExpression) argument0, (VariableReferenceExpression) argument1)); + } + else { + joinFilters.add(predicate); + } + } + + return new JoinCondition(joinClauses.build(), joinFilters.build(), newLeftAssignments.build(), newRightAssignments.build()); } - private static EquiJoinClause toEquiJoinClause(CallExpression equality, Set leftVariables, VariableAllocator variableAllocator) + @VisibleForTesting + static class JoinCondition { - checkArgument(equality.getArguments().size() == 2, "Unexpected number of arguments in binary operator equals"); - VariableReferenceExpression leftVariable = (VariableReferenceExpression) equality.getArguments().get(0); - VariableReferenceExpression rightVariable = (VariableReferenceExpression) equality.getArguments().get(1); - EquiJoinClause equiJoinClause = new EquiJoinClause(leftVariable, rightVariable); - return leftVariables.contains(leftVariable) ? 
equiJoinClause : equiJoinClause.flip(); + List joinClauses; + List joinFilters; + Map newLeftAssignments; + Map newRightAssignments; + + public JoinCondition(List joinClauses, List joinFilters, + Map left, Map right) + { + this.joinClauses = joinClauses; + this.joinFilters = joinFilters; + this.newLeftAssignments = left; + this.newRightAssignments = right; + } + + public List getJoinClauses() + { + return joinClauses; + } + + public List getJoinFilters() + { + return joinFilters; + } + + public Map getNewLeftAssignments() + { + return newLeftAssignments; + } + + public Map getNewRightAssignments() + { + return newRightAssignments; + } } private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) @@ -473,14 +596,19 @@ static class MultiJoinNode { // Use a linked hash set to ensure optimizer is deterministic private final CanonicalJoinNode node; + private final Assignments assignments; - public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List outputVariables) + public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List outputVariables, + Assignments assignments) { checkArgument(sources.size() > 1, "sources size is <= 1"); requireNonNull(sources, "sources is null"); requireNonNull(filter, "filter is null"); requireNonNull(outputVariables, "outputVariables is null"); + requireNonNull(assignments, "assignments is null"); + + this.assignments = assignments; // Plan node id doesn't matter here as we don't use this in planner this.node = new CanonicalJoinNode( new PlanNodeId(""), @@ -489,9 +617,6 @@ public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List ImmutableSet.of(), ImmutableSet.of(filter), outputVariables); - - List inputVariables = sources.stream().flatMap(source -> source.getOutputVariables().stream()).collect(toImmutableList()); - checkArgument(inputVariables.containsAll(outputVariables), "inputs do not contain all output variables"); } public RowExpression getFilter() @@ -509,6 +634,11 @@ public List 
getOutputVariables() return node.getOutputVariables(); } + public Assignments getAssignments() + { + return assignments; + } + public static Builder builder() { return new Builder(); @@ -530,25 +660,38 @@ public boolean equals(Object obj) MultiJoinNode other = (MultiJoinNode) obj; return getSources().equals(other.getSources()) && ImmutableSet.copyOf(extractConjuncts(getFilter())).equals(ImmutableSet.copyOf(extractConjuncts(other.getFilter()))) - && getOutputVariables().equals(other.getOutputVariables()); + && getOutputVariables().equals(other.getOutputVariables()) + && getAssignments().equals(other.getAssignments()); } - static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) + @Override + public String toString() + { + return "MultiJoinNode{" + + "node=" + node + + ", assignments=" + assignments + + '}'; + } + + static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) { // the number of sources is the number of joins + 1 - return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1, functionResolution, determinismEvaluator).toMultiJoinNode(); + return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1, handleComplexEquiJoins, functionResolution, determinismEvaluator).toMultiJoinNode(); } private static class JoinNodeFlattener { private final LinkedHashSet sources = new LinkedHashSet<>(); - private final List filters = new ArrayList<>(); + private final Assignments intermediateAssignments; + private final boolean handleComplexEquiJoins; + private List filters = new ArrayList<>(); private final List outputVariables; private final FunctionResolution functionResolution; private final DeterminismEvaluator determinismEvaluator; private final Lookup lookup; - JoinNodeFlattener(JoinNode node, Lookup lookup, int 
sourceLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) + JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, + DeterminismEvaluator determinismEvaluator) { requireNonNull(node, "node is null"); checkState(node.getType() == INNER, "join type must be INNER"); @@ -556,13 +699,74 @@ private static class JoinNodeFlattener this.lookup = requireNonNull(lookup, "lookup is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null"); - flattenNode(node, sourceLimit); + this.handleComplexEquiJoins = handleComplexEquiJoins; + + Map intermediateAssignments = new HashMap<>(); + flattenNode(node, sourceLimit, intermediateAssignments); + + // We resolve the intermediate assignments to only inputs of the flattened join node + ImmutableSet inputVariables = sources.stream().flatMap(s -> s.getOutputVariables().stream()).collect(toImmutableSet()); + this.intermediateAssignments = resolveAssignments(intermediateAssignments, inputVariables); + rewriteFilterWithInlinedAssignments(this.intermediateAssignments); + } + + private Assignments resolveAssignments(Map assignments, Set availableVariables) + { + HashSet resolvedVariables = new HashSet<>(); + for (VariableReferenceExpression variable : assignments.keySet()) { + resolveVariable(variable, resolvedVariables, assignments, availableVariables); + } + + return Assignments.builder().putAll(assignments).build(); + } + + private void resolveVariable(VariableReferenceExpression variable, HashSet resolvedVariables, Map assignments, Set availableVariables) + { + RowExpression expression = assignments.get(variable); + + Sets.SetView variablesToResolve = Sets.difference(Sets.difference(extractUnique(expression), availableVariables), resolvedVariables); + if (variablesToResolve.isEmpty()) { + 
resolvedVariables.add(variable); + return; + } + + variablesToResolve.forEach(variableToResolve -> resolveVariable(variableToResolve, resolvedVariables, assignments, availableVariables)); + // Modify the assignments to replace the variables with the resolved expressions + assignments.put(variable, replaceExpression(expression, assignments)); + // Mark the variable as resolved + resolvedVariables.add(variable); } - private void flattenNode(PlanNode node, int limit) + private void rewriteFilterWithInlinedAssignments(Assignments assignments) + { + ImmutableList.Builder modifiedFilters = ImmutableList.builder(); + filters.forEach(filter -> modifiedFilters.add(replaceExpression(filter, assignments.getMap()))); + filters = modifiedFilters.build(); + } + + private void flattenNode(PlanNode node, int limit, Map assignmentsBuilder) { PlanNode resolved = lookup.resolve(node); + if (resolved instanceof ProjectNode) { + ProjectNode projectNode = (ProjectNode) resolved; + // A ProjectNode could be 'hiding' a join source by building an assignment of a complex equi-join criteria like `left.key = right1.key1 + right1.key2` + // We open up the join space by tracking the assignments from this Project node; these will be inlined into the overall filters once we finish + // traversing the join graph + // We only do this if the ProjectNode assignments are deterministic + if (handleComplexEquiJoins && lookup.resolve(projectNode.getSource()) instanceof JoinNode && + projectNode.getAssignments().getExpressions().stream().allMatch(determinismEvaluator::isDeterministic)) { + //We keep track of only the non-identity assignments since these are the ones that will be inlined into the overall filters + assignmentsBuilder.putAll(getNonIdentityAssignments(projectNode.getAssignments())); + flattenNode(projectNode.getSource(), limit, assignmentsBuilder); + } + else { + sources.add(node); + } + return; + } + // (limit - 2) because you need to account for adding left and right side if (!(resolved 
instanceof JoinNode) || (sources.size() > (limit - 2))) { sources.add(node); @@ -576,8 +780,8 @@ private void flattenNode(PlanNode node, int limit) } // we set the left limit to limit - 1 to account for the node on the right - flattenNode(joinNode.getLeft(), limit - 1); - flattenNode(joinNode.getRight(), limit); + flattenNode(joinNode.getLeft(), limit - 1, assignmentsBuilder); + flattenNode(joinNode.getRight(), limit, assignmentsBuilder); joinNode.getCriteria().stream() .map(criteria -> toRowExpression(criteria, functionResolution)) .forEach(filters::add); @@ -586,7 +790,35 @@ private void flattenNode(PlanNode node, int limit) MultiJoinNode toMultiJoinNode() { - return new MultiJoinNode(sources, and(filters), outputVariables); + ImmutableSet inputVariables = sources.stream().flatMap(source -> source.getOutputVariables().stream()).collect(toImmutableSet()); + + // We could have some output variables that were possibly generated from intermediate assignments + // For each of these variables, use the intermediate assignments to replace this variable with the set of input variables it uses + + // Additionally, we build an overall set of assignments for the reordered Join node - this is used to add a wrapper Project over the updated output variables + // We do this to satisfy the invariant that the rewritten Join node must produce the same output variables as the input Join node + ImmutableSet.Builder updatedOutputVariables = ImmutableSet.builder(); + Assignments.Builder overallAssignments = Assignments.builder(); + boolean nonIdentityAssignmentsFound = false; + + for (VariableReferenceExpression outputVariable : outputVariables) { + if (inputVariables.contains(outputVariable)) { + overallAssignments.put(outputVariable, outputVariable); + updatedOutputVariables.add(outputVariable); + continue; + } + + checkState(intermediateAssignments.getMap().containsKey(outputVariable), + "Output variable [%s] not found in input variables or in intermediate assignments", 
outputVariable); + nonIdentityAssignmentsFound = true; + overallAssignments.put(outputVariable, intermediateAssignments.get(outputVariable)); + updatedOutputVariables.addAll(extractUnique(intermediateAssignments.get(outputVariable))); + } + + return new MultiJoinNode(sources, + and(filters), + updatedOutputVariables.build().asList(), + nonIdentityAssignmentsFound ? overallAssignments.build() : Assignments.of()); } } @@ -595,6 +827,7 @@ static class Builder private List sources; private RowExpression filter; private List outputVariables; + private Assignments assignments = Assignments.of(); public Builder setSources(PlanNode... sources) { @@ -608,6 +841,12 @@ public Builder setFilter(RowExpression filter) return this; } + public Builder setAssignments(Assignments assignments) + { + this.assignments = assignments; + return this; + } + public Builder setOutputVariables(VariableReferenceExpression... outputVariables) { this.outputVariables = ImmutableList.copyOf(outputVariables); @@ -616,7 +855,7 @@ public Builder setOutputVariables(VariableReferenceExpression... 
outputVariables public MultiJoinNode build() { - return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputVariables); + return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputVariables, assignments); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java index 6a7c56fb031ad..c175b6d0ed1b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import java.util.Collection; @@ -51,6 +52,17 @@ public static boolean isIdentity(Assignments assignments) return true; } + public static Map getNonIdentityAssignments(Assignments assignments) + { + ImmutableMap.Builder nonIdentityAssignments = ImmutableMap.builder(); + for (Map.Entry assignment : assignments.entrySet()) { + if (!assignment.getKey().equals(assignment.getValue())) { + nonIdentityAssignments.put(assignment); + } + } + return nonIdentityAssignments.build(); + } + public static boolean isIdentity(Assignments assignments, VariableReferenceExpression output) { RowExpression value = assignments.get(output); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index f98d44439dec0..b73297778d207 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -248,7 +248,8 @@ public void testDefaults() 
.setPullUpExpressionFromLambdaEnabled(false) .setRewriteConstantArrayContainsToInEnabled(false) .setUseHBOForScaledWriters(false) - .setRemoveRedundantCastToVarcharInJoin(true)); + .setRemoveRedundantCastToVarcharInJoin(true) + .setHandleComplexEquiJoins(false)); } @Test @@ -445,6 +446,7 @@ public void testExplicitPropertyMappings() .put("optimizer.rewrite-constant-array-contains-to-in", "true") .put("optimizer.use-hbo-for-scaled-writers", "true") .put("optimizer.remove-redundant-cast-to-varchar-in-join", "false") + .put("optimizer.handle-complex-equi-joins", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -638,7 +640,8 @@ public void testExplicitPropertyMappings() .setPullUpExpressionFromLambdaEnabled(true) .setRewriteConstantArrayContainsToInEnabled(true) .setUseHBOForScaledWriters(true) - .setRemoveRedundantCastToVarcharInJoin(false); + .setRemoveRedundantCastToVarcharInJoin(false) + .setHandleComplexEquiJoins(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java index 63dea91103d9f..4e386ab85427f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java @@ -439,6 +439,7 @@ public void testNonPushedDownJoinFilterRemoval() "SELECT 1 FROM part t0, part t1, part t2 " + "WHERE t0.partkey = t1.partkey AND t0.partkey = t2.partkey " + "AND t0.size + t1.size = t2.size", + noJoinReordering(), anyTree( join( INNER, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index 9d0f7175ad91d..062cca4bbc5c9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java 
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -100,7 +100,7 @@ /** * RowExpression visitor which verifies if given expression (actual) is matching other RowExpression given as context (expected). */ -final class RowExpressionVerifier +public final class RowExpressionVerifier extends AstVisitor { // either use variable or input reference for symbol mapping @@ -110,7 +110,7 @@ final class RowExpressionVerifier private final FunctionResolution functionResolution; private final Set lambdaArguments; - RowExpressionVerifier(SymbolAliases symbolAliases, Metadata metadata, Session session) + public RowExpressionVerifier(SymbolAliases symbolAliases, Metadata metadata, Session session) { this.symbolAliases = requireNonNull(symbolAliases, "symbolLayout is null"); this.metadata = requireNonNull(metadata, "metadata is null"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java index d60ea7b95e8a7..c6f67a02c2a05 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -15,46 +15,69 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.CachingCostProvider; import com.facebook.presto.cost.CachingStatsProvider; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.CostProvider; import com.facebook.presto.cost.PlanCostEstimate; import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; import 
com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.LogicalPropertiesProvider; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.DeterminismEvaluator; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.assertions.RowExpressionVerifier; +import com.facebook.presto.sql.planner.assertions.SymbolAliases; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator; +import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.JoinCondition; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.MultiJoinNode; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; +import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.Arrays; import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static 
com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; +import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.planner.optimizations.JoinNodeUtils.toRowExpression; +import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; public class TestJoinEnumerator { @@ -62,14 +85,20 @@ public class TestJoinEnumerator private Metadata metadata; private DeterminismEvaluator determinismEvaluator; private FunctionResolution functionResolution; + private PlanBuilder planBuilder; + private TestingRowExpressionTranslator rowExpressionTranslator; + private Session session; @BeforeClass public void setUp() { - queryRunner = new LocalQueryRunner(testSessionBuilder().build()); + session = testSessionBuilder().build(); + queryRunner = new LocalQueryRunner(session); metadata = queryRunner.getMetadata(); determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata); functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); + rowExpressionTranslator = new TestingRowExpressionTranslator(metadata); } @AfterClass(alwaysRun = true) @@ -109,7 +138,8 @@ public 
void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() MultiJoinNode multiJoinNode = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(p.values(a1), p.values(b1))), TRUE_CONSTANT, - ImmutableList.of(a1, b1)); + ImmutableList.of(a1, b1), + Assignments.of()); JoinEnumerator joinEnumerator = new JoinEnumerator( new CostComparator(1, 1, 1), multiJoinNode.getFilter(), @@ -122,6 +152,94 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() assertEquals(actual.getCost(), PlanCostEstimate.infinite()); } + @Test + public void testJoinClauseAndFilterInference() + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put("a", BIGINT); + builder.put("b", BIGINT); + builder.put("c", BIGINT); + builder.put("d", BIGINT); + Map variableMap = builder.build(); + VariableReferenceExpression a = variable("a", variableMap.get("a")); + VariableReferenceExpression b = variable("b", variableMap.get("b")); + VariableReferenceExpression c = variable("c", variableMap.get("c")); + VariableReferenceExpression d = variable("d", variableMap.get("d")); + + SymbolAliases.Builder newAliases = SymbolAliases.builder(); + newAliases.put("A", new SymbolReference("a")); + newAliases.put("B", new SymbolReference("b")); + newAliases.put("C", new SymbolReference("c")); + newAliases.put("D", new SymbolReference("d")); + SymbolAliases symbolAliases = newAliases.build(); + + // Simple join predicates on variable references + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B", null); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b", "c = d"), ImmutableSet.of(a, c), ImmutableSet.of(b, d), "A = B AND C = D", null); + // Complex join predicate - All variables must be from one join side to have the predicate be an equi-join clause + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C", null); + 
// Left and right side designation can be switched + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C", null); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c + 1"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C + 1", null); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c + 1"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C + 1", null); + // If a predicate has a mix of variables from left & right sides, the predicate is treated as a filter + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = c"), ImmutableSet.of(a), ImmutableSet.of(b, c), null, "A + B = C"); + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = 1"), ImmutableSet.of(a), ImmutableSet.of(b), null, "A + B = 1"); + // Test with multiple equi-join conditions and filters + assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = ABS(b)", "a = ceil(b-c)", "b = c + 10"), + ImmutableSet.of(a), ImmutableSet.of(b, c), "A = abs(B) AND A = ceil(B-C)", "B = C + 10"); + } + + private List toRowExpressionList(Map variableTypeMap, String... 
predicates) + { + return Arrays.stream(predicates) + .map(p -> rowExpressionTranslator.translate(p, variableTypeMap)) + .collect(Collectors.toList()); + } + + private void assertJoinCondition(SymbolAliases symbolAliases, List joinPredicates, Set leftVariables, + Set rightVariables, String expectedEquiJoinClause, String expectedJoinFilter) + { + RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session); + JoinEnumerator joinEnumerator = new JoinEnumerator( + new CostComparator(1, 1, 1), + TRUE_CONSTANT, + createContext(), + determinismEvaluator, + functionResolution, + metadata); + + JoinCondition joinConditions = joinEnumerator.extractJoinConditions(joinPredicates, + leftVariables, rightVariables, new VariableAllocator()); + + Optional equiJoinExpressionInlined = joinConditions.getJoinClauses().stream() + .map(criteria -> toRowExpression(criteria, functionResolution)) + // We may have made left or right assignments to build the equi join clause + // We inline these assignments for building the equi join clause to verify + .map(expression -> replaceExpression(expression, joinConditions.getNewLeftAssignments())) + .map(expression -> replaceExpression(expression, joinConditions.getNewRightAssignments())) + .reduce(LogicalRowExpressions::and); + + if (equiJoinExpressionInlined.isPresent()) { + assertNotNull(expectedEquiJoinClause); + assertTrue(verifier.process(expression(expectedEquiJoinClause), equiJoinExpressionInlined.get())); + } + else { + assertNull(expectedEquiJoinClause); + } + + Optional joinFilter = joinConditions.getJoinFilters().stream() + .reduce(LogicalRowExpressions::and); + + if (joinFilter.isPresent()) { + assertNotNull(expectedJoinFilter); + assertTrue(verifier.process(expression(expectedJoinFilter), joinFilter.get())); + } + else { + assertNull(expectedJoinFilter); + } + } + private Rule.Context createContext() { PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); diff --git 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java index 6ba4bc4f7c9b2..e7f96e3151763 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java @@ -15,11 +15,15 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.MultiJoinNode; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -39,7 +43,9 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.expressions.LogicalRowExpressions.and; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.MultiJoinNode.toMultiJoinNode; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; @@ 
-57,13 +63,15 @@ public class TestJoinNodeFlattener private FunctionResolution functionResolution; private LocalQueryRunner queryRunner; + private FunctionAndTypeResolver functionAndTypeResolver; @BeforeClass public void setUp() { queryRunner = new LocalQueryRunner(testSessionBuilder().build()); determinismEvaluator = new RowExpressionDeterminismEvaluator(queryRunner.getMetadata()); - functionResolution = new FunctionResolution(queryRunner.getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver()); + functionAndTypeResolver = queryRunner.getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver(); + functionResolution = new FunctionResolution(functionAndTypeResolver); } @AfterClass(alwaysRun = true) @@ -86,7 +94,7 @@ public void testDoesNotAllowOuterJoin() ImmutableList.of(equiJoinClause(a1, b1)), ImmutableList.of(a1, b1), Optional.empty()); - toMultiJoinNode(outerJoin, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator); + toMultiJoinNode(outerJoin, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator); } @Test @@ -116,7 +124,7 @@ public void testDoesNotConvertNestedOuterJoins() .setSources(leftJoin, valuesC).setFilter(createEqualsExpression(a1, c1)) .setOutputVariables(a1, b1, c1) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator), expected); } @Test @@ -149,7 +157,7 @@ public void testRetainsOutputSymbols() .setFilter(and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1))) .setOutputVariables(a1, b1) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator), expected); } @Test 
@@ -206,8 +214,9 @@ public void testCombinesCriteriaAndFilters() MultiJoinNode expected = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(valuesA, valuesB, valuesC)), and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1), bcFilter, abcFilter), - ImmutableList.of(a1, b1, b2, c1, c2)); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); + ImmutableList.of(a1, b1, b2, c1, c2), + Assignments.builder().build()); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator), expected); } @Test @@ -258,7 +267,7 @@ public void testConvertsBushyTrees() .setFilter(and(createEqualsExpression(a1, b1), createEqualsExpression(a1, c1), createEqualsExpression(d1, e1), createEqualsExpression(d2, e2), createEqualsExpression(b1, e1))) .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, false, functionResolution, determinismEvaluator), expected); } @Test @@ -311,10 +320,128 @@ public void testMoreThanJoinLimit() .setFilter(and(createEqualsExpression(a1, c1), createEqualsExpression(b1, e1))) .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), 2, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), 2, true, functionResolution, determinismEvaluator), expected); + } + + @Test + public void testProjectNodesBetweenJoinNodesAreFlattenedForComplexEquiJoins() + { + PlanBuilder p = planBuilder(); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression c1 = p.variable("C1"); + VariableReferenceExpression d1 = p.variable("D1"); + VariableReferenceExpression e1 = p.variable("E1"); + 
VariableReferenceExpression sum = p.variable("SUM"); + VariableReferenceExpression rename = p.variable("RENAME"); + VariableReferenceExpression renamePlusSum = p.variable("RENAME_PLUS_SUM"); + + ValuesNode valuesA = p.values(a1); + ValuesNode valuesB = p.values(b1); + ValuesNode valuesC = p.values(c1); + ValuesNode valuesD = p.values(d1); + ValuesNode valuesE = p.values(e1); + Assignments assignmentA1PlusB1 = Assignments.builder().put(sum, createAddExpression(a1, b1)).build(); + Assignments assignmentRenameC1 = Assignments.builder().put(rename, c1).build(); + Assignments assignmentRenamePlusSum = Assignments.builder().put(renamePlusSum, createAddExpression(rename, sum)).build(); + + ProjectNode projectOverJoin3 = p.project(assignmentRenamePlusSum, p.join( + INNER, + p.project(assignmentA1PlusB1, p.join(// projectOverJoin1 + INNER, + valuesA, + valuesB, + ImmutableList.of(equiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty())), + p.project(assignmentRenameC1, p.join(// projectOverJoin2 + INNER, + valuesC, + valuesD, + ImmutableList.of(equiJoinClause(c1, d1)), + ImmutableList.of(c1), + Optional.empty())), + ImmutableList.of(equiJoinClause(sum, rename)), + ImmutableList.of(sum, rename), + Optional.empty())); + + JoinNode topMostJoinNode = p.join( + INNER, + valuesE, + projectOverJoin3, + ImmutableList.of(equiJoinClause(e1, renamePlusSum)), + ImmutableList.of(e1, renamePlusSum), + Optional.empty()); + + MultiJoinNode expected = MultiJoinNode.builder() + .setSources(valuesA, valuesB, valuesC, valuesD, valuesE) + .setFilter(and(createEqualsExpression(a1, b1), + createEqualsExpression(c1, d1), + createEqualsExpression(createAddExpression(a1, b1), c1), + createEqualsExpression(e1, createAddExpression(c1, createAddExpression(a1, b1))))) + .setAssignments(Assignments.of(e1, e1, renamePlusSum, createAddExpression(c1, createAddExpression(a1, b1)))) + .setOutputVariables(e1, c1, a1, b1) + .build(); + MultiJoinNode actual = toMultiJoinNode(topMostJoinNode, 
noLookup(), 5, /*handleComplexEquiJoins*/ true, functionResolution, determinismEvaluator); + assertEquals(actual, expected); + + // Negative test - when handleComplexEquiJoins = false, we have a split join space; the ProjectNodes are not flattened + expected = MultiJoinNode.builder() + .setSources(valuesE, projectOverJoin3) + .setFilter(createEqualsExpression(e1, renamePlusSum)) + .setAssignments(Assignments.of()) + .setOutputVariables(e1, renamePlusSum) + .build(); + + assertEquals(toMultiJoinNode(topMostJoinNode, noLookup(), 5, /*handleComplexEquiJoins*/ false, functionResolution, determinismEvaluator), expected); + } + + @Test + public void testProjectNodesWithNonDeterministicAssignmentsAreNotFlattenedForComplexEquiJoins() + { + PlanBuilder p = planBuilder(); + VariableReferenceExpression a1 = p.variable("A1"); + VariableReferenceExpression b1 = p.variable("B1"); + VariableReferenceExpression c1 = p.variable("C1"); + VariableReferenceExpression randomPlusSum = p.variable("RANDOM_PLUS_SUM"); + Assignments nonDeterministicAssignment = Assignments.builder().put(randomPlusSum, createAddExpression(createRandomExpression(), createAddExpression(a1, b1))).build(); + + ValuesNode valuesA = p.values(a1); + ValuesNode valuesB = p.values(b1); + ValuesNode valuesC = p.values(c1); + + ProjectNode projectWithNonDeterministicAssignment = p.project(nonDeterministicAssignment, p.join( + INNER, + valuesA, + valuesB, + ImmutableList.of(equiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty())); + + JoinNode joinNodeToFlatten = p.join( + INNER, + projectWithNonDeterministicAssignment, + valuesC, + ImmutableList.of(equiJoinClause(randomPlusSum, c1)), + ImmutableList.of(), + Optional.empty()); + + MultiJoinNode expected = MultiJoinNode.builder() + .setSources(projectWithNonDeterministicAssignment, valuesC) + .setFilter(createEqualsExpression(randomPlusSum, c1)) + .setAssignments(Assignments.of()) + .setOutputVariables() + .build(); + + 
assertEquals(toMultiJoinNode(joinNodeToFlatten, noLookup(), 5, /*handleComplexEquiJoins*/ true, functionResolution, determinismEvaluator), expected); } - private RowExpression createEqualsExpression(VariableReferenceExpression left, VariableReferenceExpression right) + private CallExpression createRandomExpression() + { + return call("random", functionAndTypeResolver.lookupFunction("random", fromTypes()), DOUBLE); + } + + private RowExpression createEqualsExpression(RowExpression left, RowExpression right) { return call( OperatorType.EQUAL.name(), @@ -323,6 +450,15 @@ private RowExpression createEqualsExpression(VariableReferenceExpression left, V ImmutableList.of(left, right)); } + private RowExpression createAddExpression(RowExpression left, RowExpression right) + { + return call( + OperatorType.ADD.name(), + functionResolution.arithmeticFunction(OperatorType.ADD, left.getType(), right.getType()), + BIGINT, + ImmutableList.of(left, right)); + } + private EquiJoinClause equiJoinClause(VariableReferenceExpression variable1, VariableReferenceExpression variable2) { return new EquiJoinClause(variable1, variable2); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java index c29c366672392..79fdcfb55a04e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.Session; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.PlanNodeStatsEstimate; @@ -22,6 +23,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import 
com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; @@ -32,12 +34,14 @@ import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Optional; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; +import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -49,8 +53,13 @@ import static com.facebook.presto.metadata.FunctionAndTypeManager.qualifyObjectName; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.BROADCAST; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; @@ -59,6 +68,7 @@ import static com.facebook.presto.sql.relational.Expressions.variable; public class TestReorderJoins + extends BasePlanTest { private RuleTester tester; private FunctionResolution functionResolution; @@ -67,6 +77,37 @@ public class TestReorderJoins private static final ImmutableList> TWO_ROWS = ImmutableList.of(ImmutableList.of(), ImmutableList.of()); private static final QualifiedName RANDOM = QualifiedName.of("random"); + @DataProvider + public static Object[][] tableSpecificationPermutations() + { + return new Object[][] { + {"supplier s, partsupp ps, customer c, orders o"}, + {"supplier s, partsupp ps, orders o, customer c"}, + {"supplier s, customer c, partsupp ps, orders o"}, + {"supplier s, customer c, orders o, partsupp ps"}, + {"supplier s, orders o, partsupp ps, customer c"}, + {"supplier s, orders o, customer c, partsupp ps"}, + {"partsupp ps, supplier s, customer c, orders o"}, + {"partsupp ps, supplier s, orders o, customer c"}, + {"partsupp ps, customer c, supplier s, orders o"}, + {"partsupp ps, customer c, orders o, supplier s"}, + {"partsupp ps, orders o, supplier s, customer c"}, + {"partsupp ps, orders o, customer c, supplier s"}, + {"customer c, supplier s, partsupp ps, orders o"}, + {"customer c, supplier s, orders o, partsupp ps"}, + {"customer c, partsupp ps, supplier s, orders o"}, + {"customer c, partsupp ps, orders o, supplier s"}, + {"customer c, orders o, supplier s, partsupp ps"}, + {"customer c, orders o, partsupp ps, supplier s"}, + {"orders o, supplier s, partsupp ps, customer c"}, + {"orders o, supplier s, customer c, partsupp ps"}, + {"orders o, partsupp ps, supplier s, customer c"}, + 
{"orders o, partsupp ps, customer c, supplier s"}, + {"orders o, customer c, supplier s, partsupp ps"}, + {"orders o, customer c, partsupp ps, supplier s"} + }; + } + @BeforeClass public void setUp() { @@ -74,7 +115,8 @@ public void setUp() ImmutableList.of(), ImmutableMap.of( JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name(), - JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name()), + JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name(), + HANDLE_COMPLEX_EQUI_JOINS, "true"), Optional.of(4)); this.functionResolution = new FunctionResolution(tester.getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver()); } @@ -550,6 +592,159 @@ public void testReorderAndReplicate() values(ImmutableMap.of("A1", 0)))); } + /** + * This test asserts that join re-ordering works as expected for complex equi join clauses ('s.acctbal = c.acctbal + o.totalprice') + * and works irrespective of the order in which tables are specified in the FROM clause + * + * @param tableSpecificationOrder The table specification order + */ + @Test(dataProvider = "tableSpecificationPermutations") + public void testComplexEquiJoinCriteria(String tableSpecificationOrder) + { + // For a full connected join graph, we don't see any CrossJoins + String query = "select 1 from " + tableSpecificationOrder + " where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal = c.acctbal + o.totalprice"; + PlanMatchPattern expectedPlan = + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree(tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("SUM", "S_ACCTBAL")), + anyTree( + project(ImmutableMap.of("SUM", expression("C_ACCTBAL + O_TOTALPRICE")), + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), + anyTree( + 
tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey"))))))); + assertPlan(query, tester.getSession(), expectedPlan); + + // The plan is identical to the plan for the fully spelled out version of the Join + String fullQuery = "select 1 from (supplier s inner join partsupp ps on s.suppkey = ps.suppkey) inner join (orders o inner join customer c on c.custkey = o.custkey) " + + " on s.acctbal = c.acctbal + o.totalprice"; + assertPlan(fullQuery, tester.getSession(), expectedPlan); + } + + @Test + public void testComplexEquiJoinCriteriaForDisjointGraphs() + { + // If the join clause is written with the Left/Right side referring to both sides of a Join node, an equi-join condition cannot be inferred + // and the join space is broken up. Hence, we observe a CrossJoin node + assertPlan("select 1 from supplier s, partsupp ps, customer c, orders o where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal - c.acctbal = o.totalprice", tester.getSession(), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("C_CUSTKEY", "O_CUSTKEY"), equiJoinClause("SUBTRACT", "O_TOTALPRICE")), + anyTree( + project(ImmutableMap.of("SUBTRACT", expression("S_ACCTBAL - C_ACCTBAL")), + join(INNER, + ImmutableList.of(), //CrossJoin + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree(tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey")))), + anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice")))))); + + // The table specification order determines the join order for such cases + // With the below table specification order, the planner adds the complex 
equi-join condition as a FilterNode on top of a JoinNode + assertPlan("select 1 from orders o, customer c, supplier s, partsupp ps where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal - c.acctbal = o.totalprice", tester.getSession(), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree( + tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), + anyTree( + filter("O_TOTALPRICE = S_ACCTBAL - C_ACCTBAL", + join(INNER, + ImmutableList.of(), //CrossJoin + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree(tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), + anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey"))))))))); + + // For sub-graphs that are fully connected, join-reordering works with complex predicates as expected + // The rest of the join graph is connected using a CrossJoin + assertPlan("select 1 " + + "from orders o, customer c, supplier s, partsupp ps, part p " + + "where s.suppkey = ps.suppkey " + + " and c.custkey = o.custkey " + + " and s.acctbal = c.acctbal + o.totalprice" + + " and ps.partkey - p.partkey = 0 ", + tester.getSession(), + anyTree( + filter("PS_PARTKEY - P_PARTKEY = 0", + join(INNER, + ImmutableList.of(), // CrossJoin + join(INNER, + ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), + anyTree( + tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey", "PS_PARTKEY", "partkey"))), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("SUM", "S_ACCTBAL")), + anyTree( + project(ImmutableMap.of("SUM", expression("C_ACCTBAL + O_TOTALPRICE")), + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), + anyTree( + 
tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), + anyTree( + tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey")))))), + anyTree( + tableScan("part", ImmutableMap.of("P_PARTKEY", "partkey"))))))); + } + + @Test + public void testComplexEquiJoinCriteriaForJoinsWithUSINGClause() + { + // Projecting all the columns from the sources while joining tables with a USING clause introduces intermediate Project Nodes in the join graph + // This breaks join-reordering, and we get table-specification ordering for the Join graph + String usingQueryWithStarProjection = "select * from orders join lineitem USING (orderkey) join customer USING (custkey)"; + + Session session = Session.builder(tester.getSession()).setSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, "false").build(); + assertPlan(usingQueryWithStarProjection, session, + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("L_ORDERKEY", "O_ORDERKEY")), + anyTree( + tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey"))), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_ORDERKEY", "orderkey"))))), + anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey")))))); + + // With HANDLE_COMPLEX_EQUI_JOINS turned on, the intermediate Project nodes are handled and join-reordering works as expected + session = Session.builder(tester.getSession()).setSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, "true").build(); + assertPlan(usingQueryWithStarProjection, session, + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("L_ORDERKEY", "O_ORDERKEY")), + anyTree( + tableScan("lineitem", ImmutableMap.of("L_ORDERKEY", "orderkey"))), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), + anyTree( + tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_ORDERKEY", "orderkey"))), + 
anyTree( + tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey")))))))); + } + private RuleAssert assertReorderJoins() { return tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), tester.getMetadata())); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java index 369b0277abe66..374310940b85b 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java @@ -24,6 +24,7 @@ import com.google.common.collect.Iterables; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOINS_NOT_NULL_INFERENCE_STRATEGY; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -470,6 +471,29 @@ public void testJoinWithComplexExpressions3() "SELECT SUM(custkey) FROM lineitem JOIN orders ON lineitem.orderkey + 1 = orders.orderkey + 1", // H2 takes a million years because it can't join efficiently on a non-indexed field/expression "SELECT SUM(custkey) FROM lineitem JOIN orders ON lineitem.orderkey = orders.orderkey "); + + Session handleComplexEquiJoins = Session.builder(getSession()) + .setSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, "true") + .build(); + + assertQueryWithSameQueryRunner( + handleComplexEquiJoins, + "select c.custkey, ps.partkey, s.suppkey, o.orderkey " + + "from customer c, " + + " partsupp ps, " + + " orders o, " + + " supplier s " + + "where s.suppkey = ps.suppkey " + + " and c.custkey = o.custkey " + + " and s.nationkey + ps.partkey = c.nationkey " + + "order by c.custkey, ps.partkey, s.suppkey, o.orderkey", + noJoinReordering(), + "select c.custkey, ps.partkey, s.suppkey, 
o.orderkey " + + "from (customer c inner join orders o ON c.custkey = o.custkey) " + + " inner join " + + " (partsupp ps inner join supplier s ON s.suppkey = ps.suppkey) " + + " on s.nationkey + ps.partkey = c.nationkey " + + "order by c.custkey, ps.partkey, s.suppkey, o.orderkey"); } @Test diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index dc32f4fff33c4..6892d8aa94d66 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -173,6 +173,13 @@ protected void assertQueryWithSameQueryRunner(Session actualSession, @Language(" { QueryAssertions.assertQuery(queryRunner, actualSession, query, queryRunner, expectedSession, query, false, false); } + + protected void assertQueryWithSameQueryRunner(Session actualSession, @Language("SQL") String actual, Session expectedSession, @Language("SQL") String expected) + { + checkArgument(!actual.equals(expected)); + QueryAssertions.assertQuery(queryRunner, actualSession, actual, queryRunner, expectedSession, expected, false, false); + } + protected void assertQuery(Session session, @Language("SQL") String actual, @Language("SQL") String expected) { QueryAssertions.assertQuery(queryRunner, session, actual, expectedQueryRunner, expected, false, false);