diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index f3abbee92778c..ebec09098fe37 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -790,7 +790,7 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext> context) { - ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context, true); ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); return new UnionNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), rewrittenSubPlans, ImmutableList.copyOf(rewrittenVariableMapping.keySet()), fromListMultimap(rewrittenVariableMapping)); } @@ -798,7 +798,7 @@ public PlanNode visitUnion(UnionNode node, RewriteContext> context) { - ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context, false); ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); return new IntersectNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), rewrittenSubPlans, ImmutableList.copyOf(rewrittenVariableMapping.keySet()), fromListMultimap(rewrittenVariableMapping)); } @@ -806,17 +806,17 @@ public PlanNode visitIntersect(IntersectNode node, RewriteContext> context) { - ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context, false); ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); return new ExceptNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), rewrittenSubPlans, ImmutableList.copyOf(rewrittenVariableMapping.keySet()), fromListMultimap(rewrittenVariableMapping)); } - private ListMultimap rewriteSetOperationVariableMapping(SetOperationNode node, RewriteContext> context) + private ListMultimap rewriteSetOperationVariableMapping(SetOperationNode node, RewriteContext> context, boolean pruneUnreferencedOutput) { // Find out which output variables we need to keep ImmutableListMultimap.Builder rewrittenVariableMappingBuilder = ImmutableListMultimap.builder(); for (VariableReferenceExpression variable : node.getOutputVariables()) { - if (context.get().contains(variable)) { + if (context.get().contains(variable) || !pruneUnreferencedOutput) { rewrittenVariableMappingBuilder.putAll( variable, node.getVariableMapping().get(variable)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java index 33e69c9670302..3b65ee64f1cf0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java @@ -19,11 +19,13 @@ import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.assertions.OptimizerAssert; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.plan.WindowNode; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -32,6 +34,8 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.except; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.intersect; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; @@ -89,6 +93,52 @@ public void windowNodePruning() values("user_uuid"))))); } + @Test + public void testIntersectNodePruning() + { + assertRuleApplication() + .on(p -> + p.output(ImmutableList.of("regionkey"), ImmutableList.of(p.variable("regionkey_16")), + p.project(Assignments.of(p.variable("regionkey_16"), p.variable("regionkey_16")), + p.intersect( + ImmutableListMultimap.builder() + .putAll(p.variable("nationkey_15"), p.variable("nationkey"), p.variable("regionkey_6")) + .putAll(p.variable("regionkey_16"), p.variable("regionkey"), p.variable("regionkey_6")) + .build(), + ImmutableList.of( + p.values(p.variable("nationkey"), p.variable("regionkey")), + p.values(p.variable("regionkey_6"))))))) + .matches( + output( + project( + intersect( + values("nationkey", "regionkey"), + values("regionkey_6"))))); + } + + @Test + public void testExceptNodePruning() + { + assertRuleApplication() + .on(p -> + p.output(ImmutableList.of("regionkey"), ImmutableList.of(p.variable("regionkey_16")), + p.project(Assignments.of(p.variable("regionkey_16"), p.variable("regionkey_16")), + p.except( + ImmutableListMultimap.builder() + .putAll(p.variable("nationkey_15"), p.variable("nationkey"), p.variable("regionkey_6")) + .putAll(p.variable("regionkey_16"), p.variable("regionkey"), p.variable("regionkey_6")) + .build(), + ImmutableList.of( + p.values(p.variable("nationkey"), p.variable("regionkey")), + p.values(p.variable("regionkey_6"))))))) + .matches( + output( + project( + except( + values("nationkey", "regionkey"), + values("regionkey_6"))))); + } + private OptimizerAssert assertRuleApplication() { RuleTester tester = tester();