From f20be3989f9ca425012484e194243c89887abbde Mon Sep 17 00:00:00 2001 From: feilong-liu Date: Wed, 8 Nov 2023 10:51:18 -0800 Subject: [PATCH] Fix prune output rules for intersect and except nodes We should not prune output of intersect and except nodes. --- .../PruneUnreferencedOutputs.java | 10 ++-- .../TestPruneUnreferencedOutputs.java | 50 +++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index f3abbee92778c..ebec09098fe37 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -790,7 +790,7 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext> context) { - ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context, true); ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); return new UnionNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), rewrittenSubPlans, ImmutableList.copyOf(rewrittenVariableMapping.keySet()), fromListMultimap(rewrittenVariableMapping)); } @@ -798,7 +798,7 @@ public PlanNode visitUnion(UnionNode node, RewriteContext> context) { - ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context, false); ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); return new IntersectNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), rewrittenSubPlans, ImmutableList.copyOf(rewrittenVariableMapping.keySet()), fromListMultimap(rewrittenVariableMapping)); } @@ -806,17 +806,17 @@ public PlanNode visitIntersect(IntersectNode node, RewriteContext> context) { - ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context); + ListMultimap rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, context, false); ImmutableList rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping); return new ExceptNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), rewrittenSubPlans, ImmutableList.copyOf(rewrittenVariableMapping.keySet()), fromListMultimap(rewrittenVariableMapping)); } - private ListMultimap rewriteSetOperationVariableMapping(SetOperationNode node, RewriteContext> context) + private ListMultimap rewriteSetOperationVariableMapping(SetOperationNode node, RewriteContext> context, boolean pruneUnreferencedOutput) { // Find out which output variables we need to keep ImmutableListMultimap.Builder rewrittenVariableMappingBuilder = ImmutableListMultimap.builder(); for (VariableReferenceExpression variable : node.getOutputVariables()) { - if (context.get().contains(variable)) { + if (context.get().contains(variable) || !pruneUnreferencedOutput) { rewrittenVariableMappingBuilder.putAll( variable, node.getVariableMapping().get(variable)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java index 33e69c9670302..3b65ee64f1cf0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java @@ -19,11 +19,13 @@ import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.assertions.OptimizerAssert; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.plan.WindowNode; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -32,6 +34,8 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.except; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.intersect; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; @@ -89,6 +93,52 @@ public void windowNodePruning() values("user_uuid"))))); } + @Test + public void testIntersectNodePruning() + { + assertRuleApplication() + .on(p -> + p.output(ImmutableList.of("regionkey"), ImmutableList.of(p.variable("regionkey_16")), + p.project(Assignments.of(p.variable("regionkey_16"), p.variable("regionkey_16")), + p.intersect( + ImmutableListMultimap.builder() + .putAll(p.variable("nationkey_15"), p.variable("nationkey"), p.variable("regionkey_6")) + .putAll(p.variable("regionkey_16"), p.variable("regionkey"), p.variable("regionkey_6")) + .build(), + ImmutableList.of( + p.values(p.variable("nationkey"), p.variable("regionkey")), + p.values(p.variable("regionkey_6"))))))) + .matches( + output( + project( + intersect( + values("nationkey", "regionkey"), + values("regionkey_6"))))); + } + + @Test + public void testExceptNodePruning() + { + assertRuleApplication() + .on(p -> + p.output(ImmutableList.of("regionkey"), ImmutableList.of(p.variable("regionkey_16")), + p.project(Assignments.of(p.variable("regionkey_16"), p.variable("regionkey_16")), + p.except( + ImmutableListMultimap.builder() + .putAll(p.variable("nationkey_15"), p.variable("nationkey"), p.variable("regionkey_6")) + .putAll(p.variable("regionkey_16"), p.variable("regionkey"), p.variable("regionkey_6")) + .build(), + ImmutableList.of( + p.values(p.variable("nationkey"), p.variable("regionkey")), + p.values(p.variable("regionkey_6"))))))) + .matches( + output( + project( + except( + values("nationkey", "regionkey"), + values("regionkey_6"))))); + } + private OptimizerAssert assertRuleApplication() { RuleTester tester = tester();