diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java index 10c0bfbfa0889..f152d72006cfc 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java @@ -37,6 +37,8 @@ import java.util.Map; +import static com.facebook.presto.SystemSessionProperties.CTE_MATERIALIZATION_STRATEGY; +import static com.facebook.presto.SystemSessionProperties.CTE_PARTITIONING_PROVIDER_CATALOG; import static com.facebook.presto.SystemSessionProperties.HISTORY_BASED_OPTIMIZATION_PLAN_CANONICALIZATION_STRATEGY; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.PARTIAL_AGGREGATION_STRATEGY; @@ -258,6 +260,23 @@ public void testPartialAggStatisticsGroupByPartKey() } } + @Test + public void testHistoryBasedStatsCalculatorCTE() + { + String sql = "with t1 as (select orderkey, orderstatus from orders where totalprice > 100), t2 as (select orderkey, totalprice from orders where custkey > 100) " + + "select orderstatus, sum(totalprice) from t1 join t2 on t1.orderkey=t2.orderkey group by orderstatus"; + Session cteMaterialization = Session.builder(defaultSession()) + .setSystemProperty(CTE_MATERIALIZATION_STRATEGY, "ALL") + .setSystemProperty(CTE_PARTITIONING_PROVIDER_CATALOG, "hive") + .build(); + // CBO Statistics + assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(Double.NaN))); + + // HBO Statistics + executeAndTrackHistory(sql, cteMaterialization); + assertPlan(cteMaterialization, sql, anyTree(node(ProjectNode.class, anyTree(any())).withOutputRowCount(3))); + } + @Override protected void assertPlan(@Language("SQL") String query, PlanMatchPattern pattern) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index acb3155e71012..a8ca6e635a159 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -24,6 +24,8 @@ import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.GroupingSetDescriptor; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.CteConsumerNode; +import com.facebook.presto.spi.plan.CteProducerNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; @@ -53,6 +55,7 @@ import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; @@ -830,6 +833,25 @@ public Optional visitAggregation(AggregationNode node, Context context return Optional.of(canonicalPlan); } + @Override + public Optional visitSequence(SequenceNode node, Context context) + { + node.getCteProducers().forEach(x -> x.accept(this, context)); + return node.getPrimarySource().accept(this, context); + } + + @Override + public Optional visitCteProducer(CteProducerNode node, Context context) + { + return node.getSource().accept(this, context); + } + + @Override + public Optional visitCteConsumer(CteConsumerNode node, Context context) + { + return node.getOriginalSource().accept(this, context); + } + private Aggregation getCanonicalAggregation(Aggregation aggregation, Map context) { return new Aggregation(