From 408cb8f0efc78fb27d046d547a12d3c9fe87c902 Mon Sep 17 00:00:00 2001 From: feilong-liu Date: Sat, 13 Apr 2024 17:49:58 -0700 Subject: [PATCH] Fix CTE reference in analyzer --- .../presto/hive/TestCteExecution.java | 14 ++++ .../presto/sql/planner/RelationPlanner.java | 2 +- .../presto/sql/planner/SqlPlannerContext.java | 69 +++---------------- .../TestLogicalCteOptimizer.java | 8 +-- 4 files changed, 29 insertions(+), 64 deletions(-) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java index 5b1fccf270e5f..ad388b6dd96e2 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java @@ -1153,6 +1153,20 @@ public void testWrittenIntemediateByteLimit() assertQueryFails(session, testQuery, "Query has exceeded WrittenIntermediate Limit of 0MB.*"); } + @Test + public void testNestedCteWithSameName() + { + String testQuery = "with t1 as ( select orderkey k from orders where orderkey > 5), t2 as ( select orderkey k from orders where orderkey < 10 ), t3 as " + + "( select t1.k, t2.k from t1 left join t2 on t1.k=t2.k ), t4 as ( with t2 as ( select orderkey k from orders where orderkey > 5 ), " + + "t1 as ( select orderkey k from orders where orderkey < 10 ), t3 as ( select t1.k, t2.k from t1 left join t2 on t1.k=t2.k ) select * from t3 ) " + + "select * from t3 except select * from t4"; + QueryRunner queryRunner = getQueryRunner(); + compareResults(queryRunner.execute(getMaterializedSession(), + testQuery), + queryRunner.execute(getSession(), + testQuery)); + } + private void compareResults(MaterializedResult actual, MaterializedResult expected) { compareResults(actual, expected, false); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 8e925730a1d19..a88e3bc586aed 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -190,7 +190,7 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) } else { // cte considered for materialization - String normalizedCteId = context.getCteInfo().normalize(analysis, namedQuery.getQuery(), cteName); + String normalizedCteId = context.getCteInfo().normalize(NodeRef.of(namedQuery.getQuery()), cteName); session.getCteInformationCollector().addCTEReference(cteName, normalizedCteId, namedQuery.isFromView(), true); subPlan = new RelationPlan( new CteReferenceNode(getSourceLocation(node.getLocation()), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java index 42ebb2601ced9..2e6dc763ee882 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java @@ -15,17 +15,13 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; -import com.facebook.presto.sql.tree.DefaultTraversalVisitor; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Query; -import com.facebook.presto.sql.tree.Table; import com.google.common.annotations.VisibleForTesting; -import java.util.Comparator; import java.util.HashMap; import java.util.Map; -import java.util.TreeSet; import static com.facebook.presto.SystemSessionProperties.getMaxLeafNodesInPlan; import static com.facebook.presto.SystemSessionProperties.isLeafNodeLimitEnabled; @@ -73,64 +69,19 @@ public class CteInfo @VisibleForTesting public static final String delimiter = "_*%$_"; // never decreases - private int currentQueryScopeId; + private int prefix; - // Maps a set of Query objects, including the parent query statement and all its referenced statements, - // to a unique scope identifier. Each set of related queries shares the same scope. - Map, String> queryNodeScopeIdMap = new HashMap<>(); + // Map a cte Query to a unique ID, which will be used in CTE reference node to identify the same CTE + private final Map, String> cteQueryUniqueIdMap = new HashMap<>(); - public String normalize(Analysis analysis, Query query, String cteName) + public String normalize(NodeRef queryNodeRef, String cteName) { - QueryReferenceCollectorContext context = new QueryReferenceCollectorContext(); - context.getReferencedQuerySet().add(query); - query.accept(new QueryReferenceCollector(analysis), context); - TreeSet normalizedKey = context.getReferencedQuerySet(); - if (!queryNodeScopeIdMap.containsKey(normalizedKey)) { - queryNodeScopeIdMap.put(normalizedKey, String.valueOf(currentQueryScopeId++)); - } - return queryNodeScopeIdMap.get(normalizedKey) + delimiter + cteName; - } - - private class QueryReferenceCollector - extends DefaultTraversalVisitor - { - private final Analysis analysis; - - public QueryReferenceCollector(Analysis analysis) - { - this.analysis = analysis; - } - - @Override - protected Void visitTable(Table node, QueryReferenceCollectorContext context) - { - Analysis.NamedQuery namedQuery = analysis.getNamedQuery(node); - if (namedQuery != null) { - context.addQuery(namedQuery.getQuery()); - process(namedQuery.getQuery(), context); - } - return null; - } - } - - private class QueryReferenceCollectorContext - { - private final TreeSet referencedQuerySet; - - public QueryReferenceCollectorContext() - { - this.referencedQuerySet = new TreeSet<>(Comparator.comparingInt(Query::hashCode)); - } - - public void addQuery(Query ref) - { - this.referencedQuerySet.add(ref); - } - - public TreeSet getReferencedQuerySet() - { - return referencedQuerySet; + if (cteQueryUniqueIdMap.containsKey(queryNodeRef)) { + return cteQueryUniqueIdMap.get(queryNodeRef) + delimiter + cteName; } + String identityString = String.valueOf(prefix++); + cteQueryUniqueIdMap.put(queryNodeRef, identityString); + return identityString + delimiter + cteName; } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java index fc73ec6a4d402..93ab0e1ea7653 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java @@ -624,13 +624,13 @@ public void testHeuristicMaterializationWithDeepNestedCteUsage3() "SELECT * FROM a\n", anyTree( sequence( - cteProducer(addQueryScopeDelimiter("a", 1), anyTree(tableScan("orders"))), + cteProducer(addQueryScopeDelimiter("a", 2), anyTree(tableScan("orders"))), anyTree(PlanMatchPattern.union( PlanMatchPattern.union( PlanMatchPattern.union( - anyTree(tableScan("orders")), anyTree(cteConsumer(addQueryScopeDelimiter("a", 1)))), - anyTree(cteConsumer(addQueryScopeDelimiter("a", 1)))), - anyTree(cteConsumer(addQueryScopeDelimiter("a", 1)))))))); + anyTree(tableScan("orders")), anyTree(cteConsumer(addQueryScopeDelimiter("a", 2)))), + anyTree(cteConsumer(addQueryScopeDelimiter("a", 2)))), + anyTree(cteConsumer(addQueryScopeDelimiter("a", 2)))))))); } @Test