diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java index 5e89895eeddb0..2e801235beb5f 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java @@ -167,6 +167,28 @@ public void testRefinedCtesOutsideScope() queryRunner.execute(getSession(), testQuery)); } + @Test + public void testRedefinedCteWithSameDefinitionDifferentBase() + { + String testQuery = "SELECT (with test_base AS (SELECT colB FROM (VALUES (1)) AS TempTable(colB)), \n" + + "test_cte as ( SELECT colB FROM test_base)\n" + + "SELECT * FROM test_cte\n" + + "),\n" + + "(WITH test_base AS (\n" + + " SELECT text_column\n" + + " FROM (VALUES ('Some Text', 9)) AS t (text_column, number_column)\n" + + "), \n" + + "test_cte AS (\n" + + " SELECT * FROM test_base\n" + + ")\n" + + "SELECT CONCAT(text_column , 'XYZ') FROM test_cte\n" + + ")\n"; + QueryRunner queryRunner = getQueryRunner(); + compareResults( + queryRunner.execute(getMaterializedSession(), testQuery), + queryRunner.execute(getSession(), testQuery)); + } + @Test public void testComplexRefinedCtesOutsideScope() { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 986f84c5e5176..2d2f39ce6cf1d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -184,15 +184,13 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) if (namedQuery.isFromView()) { cteName = createQualifiedObjectName(session, node, node.getName()).toString(); } - context.getNestedCteStack().push(cteName, namedQuery.getQuery()); RelationPlan subPlan = process(namedQuery.getQuery(), context); - context.getNestedCteStack().pop(namedQuery.getQuery()); boolean shouldBeMaterialized = getCteMaterializationStrategy(session).equals(ALL) && isCteMaterializable(subPlan.getRoot().getOutputVariables()); session.getCteInformationCollector().addCTEReference(cteName, namedQuery.isFromView(), shouldBeMaterialized); if (shouldBeMaterialized) { subPlan = new RelationPlan( new CteReferenceNode(getSourceLocation(node.getLocation()), - idAllocator.getNextId(), subPlan.getRoot(), context.getNestedCteStack().getRawPath(cteName)), + idAllocator.getNextId(), subPlan.getRoot(), context.getCteInfo().normalize(analysis, namedQuery.getQuery(), cteName)), subPlan.getScope(), subPlan.getFieldMappings()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java index cb1f538e6c4bf..42ebb2601ced9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SqlPlannerContext.java @@ -15,13 +15,17 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; +import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Query; +import com.facebook.presto.sql.tree.Table; import com.google.common.annotations.VisibleForTesting; +import java.util.Comparator; import java.util.HashMap; import java.util.Map; -import java.util.Stack; +import java.util.TreeSet; import static com.facebook.presto.SystemSessionProperties.getMaxLeafNodesInPlan; import static com.facebook.presto.SystemSessionProperties.isLeafNodeLimitEnabled; @@ -34,18 +38,18 @@ public class SqlPlannerContext private int leafNodesInLogicalPlan; private final SqlToRowExpressionTranslator.Context translatorContext; - private final NestedCteStack nestedCteStack; + private final CteInfo cteInfo; public SqlPlannerContext(int leafNodesInLogicalPlan) { this.leafNodesInLogicalPlan = leafNodesInLogicalPlan; this.translatorContext = new SqlToRowExpressionTranslator.Context(); - this.nestedCteStack = new NestedCteStack(); + this.cteInfo = new CteInfo(); } - public NestedCteStack getNestedCteStack() + public CteInfo getCteInfo() { - return nestedCteStack; + return cteInfo; } public SqlToRowExpressionTranslator.Context getTranslatorContext() @@ -64,57 +68,69 @@ public void incrementLeafNodes(Session session) } } - public class NestedCteStack + public class CteInfo { @VisibleForTesting public static final String delimiter = "_*%$_"; - private final Stack cteStack; - private final Map rawCtePathMap; + // never decreases + private int currentQueryScopeId; - public NestedCteStack() - { - this.cteStack = new Stack<>(); - this.rawCtePathMap = new HashMap<>(); - } + // Maps a set of Query objects, including the parent query statement and all its referenced statements, + // to a unique scope identifier. Each set of related queries shares the same scope. + Map, String> queryNodeScopeIdMap = new HashMap<>(); - public void push(String cteName, Query query) + public String normalize(Analysis analysis, Query query, String cteName) { - this.cteStack.push(cteName); - if (query.getWith().isPresent()) { - // All ctes defined in this context should have their paths updated - query.getWith().get().getQueries().forEach(with -> this.addNestedCte(with.getName().toString())); + QueryReferenceCollectorContext context = new QueryReferenceCollectorContext(); + context.getReferencedQuerySet().add(query); + query.accept(new QueryReferenceCollector(analysis), context); + TreeSet normalizedKey = context.getReferencedQuerySet(); + if (!queryNodeScopeIdMap.containsKey(normalizedKey)) { + queryNodeScopeIdMap.put(normalizedKey, String.valueOf(currentQueryScopeId++)); } + return queryNodeScopeIdMap.get(normalizedKey) + delimiter + cteName; } - public void pop(Query query) + private class QueryReferenceCollector + extends DefaultTraversalVisitor { - this.cteStack.pop(); - if (query.getWith().isPresent()) { - query.getWith().get().getQueries().forEach(with -> this.removeNestedCte(with.getName().toString())); + private final Analysis analysis; + + public QueryReferenceCollector(Analysis analysis) + { + this.analysis = analysis; } - } - public String getRawPath(String cteName) - { - if (!this.rawCtePathMap.containsKey(cteName)) { - return cteName; + @Override + protected Void visitTable(Table node, QueryReferenceCollectorContext context) + { + Analysis.NamedQuery namedQuery = analysis.getNamedQuery(node); + if (namedQuery != null) { + context.addQuery(namedQuery.getQuery()); + process(namedQuery.getQuery(), context); + } + return null; } - return this.rawCtePathMap.get(cteName); } - private void addNestedCte(String cteName) + private class QueryReferenceCollectorContext { - this.rawCtePathMap.put(cteName, getCurrentRelativeCtePath() + delimiter + cteName); - } + private final TreeSet referencedQuerySet; - private void removeNestedCte(String cteName) - { - this.rawCtePathMap.remove(cteName); - } + public QueryReferenceCollectorContext() + { + this.referencedQuerySet = new TreeSet<>(Comparator.comparingInt(Query::hashCode)); + } - public String getCurrentRelativeCtePath() - { - return String.join(delimiter, cteStack); + public void addQuery(Query ref) + { + this.referencedQuerySet.add(ref); + } + + public TreeSet getReferencedQuerySet() + { + return referencedQuerySet; + } } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java index 57cff35abbd97..9dc8e0c1befc0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLogicalCteOptimizer.java @@ -23,11 +23,12 @@ import java.util.List; import static com.facebook.presto.SystemSessionProperties.CTE_MATERIALIZATION_STRATEGY; -import static com.facebook.presto.sql.planner.SqlPlannerContext.NestedCteStack.delimiter; +import static com.facebook.presto.sql.planner.SqlPlannerContext.CteInfo.delimiter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.cteConsumer; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.cteProducer; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.lateral; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sequence; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; @@ -39,10 +40,37 @@ public class TestLogicalCteOptimizer public void testConvertSimpleCte() { assertUnitPlan("WITH temp as (SELECT orderkey FROM ORDERS) " + - "SELECT * FROM temp t1 ", + "SELECT * FROM temp t JOIN temp t2 ON true ", anyTree( - sequence(cteProducer("temp", anyTree(tableScan("orders"))), - anyTree(cteConsumer("temp"))))); + sequence(cteProducer(addQueryScopeDelimiter("temp", 0), anyTree(tableScan("orders"))), + anyTree(cteConsumer(addQueryScopeDelimiter("temp", 0)))))); + } + + @Test + public void testSimpleRedefinedCteWithSameNameDefinedAgain() + { + assertUnitPlan("WITH \n" + + "test_base AS (SELECT colB FROM (VALUES (1), (2)) AS TempTable(colB)),\n" + + "test AS (\n" + + " \n" + + " WITH test_base as (SELECT colA FROM (VALUES (1), (2)) AS TempTable(colA)),\n" + + " test1 AS (\n" + + " WITH test2 AS(\n" + + " SELECT * FROM test_base\n" + + " )\n" + + " SELECT * FROM test2\n" + + " )\n" + + " SELECT * FROM test1\n" + + ")\n" + + "\n" + + "SELECT * FROM test", + anyTree( + sequence( + cteProducer(addQueryScopeDelimiter("test", 3), anyTree(cteConsumer(addQueryScopeDelimiter("test1", 2)))), + cteProducer(addQueryScopeDelimiter("test1", 2), anyTree(cteConsumer(addQueryScopeDelimiter("test2", 1)))), + cteProducer(addQueryScopeDelimiter("test2", 1), anyTree(cteConsumer(addQueryScopeDelimiter("test_base", 0)))), + cteProducer(addQueryScopeDelimiter("test_base", 0), anyTree(values("colA"))), + anyTree(cteConsumer(addQueryScopeDelimiter("test", 3)))))); } @Test @@ -53,9 +81,9 @@ public void testSimpleRedefinedCteWithSameName() "SELECT * FROM temp", anyTree( sequence( - cteProducer("temp", anyTree(cteConsumer("temp" + delimiter + "temp"))), - cteProducer("temp" + delimiter + "temp", anyTree(tableScan("orders"))), - anyTree(cteConsumer("temp"))))); + cteProducer(addQueryScopeDelimiter("temp", 1), anyTree(cteConsumer(addQueryScopeDelimiter("temp", 0)))), + cteProducer(addQueryScopeDelimiter("temp", 0), anyTree(tableScan("orders"))), + anyTree(cteConsumer(addQueryScopeDelimiter("temp", 1)))))); } @Test @@ -76,15 +104,74 @@ public void testComplexRedefinedNestedCtes() "SELECT cte3.*, cte2.orderkey FROM cte3 JOIN cte2 ON cte3.custkey = cte2.orderkey", anyTree( sequence( - cteProducer("cte3", anyTree(tableScan("customer"))), - cteProducer("cte2", anyTree(cteConsumer("cte2" + delimiter + "cte3"))), - cteProducer("cte2" + delimiter + "cte3", anyTree(cteConsumer("cte2" + delimiter + "cte3" + delimiter + "cte4"))), - cteProducer("cte2" + delimiter + "cte3" + delimiter + "cte4", anyTree(cteConsumer("cte1"))), - cteProducer("cte1", anyTree(tableScan("orders"))), + cteProducer(addQueryScopeDelimiter("cte3", 0), anyTree(tableScan("customer"))), + cteProducer(addQueryScopeDelimiter("cte2", 4), anyTree(cteConsumer(addQueryScopeDelimiter("cte3", 3)))), + cteProducer(addQueryScopeDelimiter("cte3", 3), anyTree(cteConsumer(addQueryScopeDelimiter("cte4", 2)))), + cteProducer(addQueryScopeDelimiter("cte4", 2), anyTree(cteConsumer(addQueryScopeDelimiter("cte1", 1)))), + cteProducer(addQueryScopeDelimiter("cte1", 1), anyTree(tableScan("orders"))), anyTree( join( - anyTree(cteConsumer("cte3")), - anyTree(cteConsumer("cte2"))))))); + anyTree(cteConsumer(addQueryScopeDelimiter("cte3", 0))), + anyTree(cteConsumer(addQueryScopeDelimiter("cte2", 4)))))))); + } + + @Test + public void testRedefinedCteConflictingNamesInDifferentScope() + { + assertUnitPlan("WITH test AS (SELECT colA FROM (VALUES (1), (2)) AS TempTable(colA)),\n" + + " _query AS (\n" + + " with test AS (\n" + + " SELECT * FROM test\n" + + " )\n" + + " SELECT * FROM test\n" + + " )\n" + + " SELECT * FROM _query", + anyTree( + sequence( + cteProducer(addQueryScopeDelimiter("_query", 2), anyTree(cteConsumer(addQueryScopeDelimiter("test", 1)))), + cteProducer(addQueryScopeDelimiter("test", 1), anyTree(cteConsumer(addQueryScopeDelimiter("test", 0)))), + cteProducer(addQueryScopeDelimiter("test", 0), anyTree(values("colA"))), + anyTree(cteConsumer(addQueryScopeDelimiter("_query", 2)))))); + } + + @Test + public void testCtesDefinedInEntirelyDifferentScope() + { + // From clause is visited first + assertUnitPlan("SELECT \n" + + " *, (WITH T as (SELECT colA FROM (VALUES (1), (2)) AS TempTable(colA)) SELECT * FROM T)\n" + + "FROM (\n" + + " WITH T AS ( \n" + + " SELECT ColumnA, ColumnB FROM (\n" + + " VALUES \n" + + " (1, 'A'),\n" + + " (2, 'B'),\n" + + " (3, 'C'),\n" + + " (4, 'D')\n" + + " ) AS TempTable(ColumnA, ColumnB)\n" + + " )\n" + + " SELECT * FROM T JOIN T ON TRUE" + + ")", + anyTree( + sequence( + cteProducer(addQueryScopeDelimiter("T", 0), anyTree(values("ColumnA", "ColumnB"))), + cteProducer(addQueryScopeDelimiter("T", 1), anyTree(values("colA"))), + anyTree(lateral(ImmutableList.of(), + anyTree(join(anyTree(cteConsumer(addQueryScopeDelimiter("T", 0))), anyTree(cteConsumer(addQueryScopeDelimiter("T", 0))))), + anyTree(cteConsumer(addQueryScopeDelimiter("T", 1)))))))); + } + + @Test + public void testNestedCtesReused() + { + assertUnitPlan("WITH cte1 AS ( WITH cte2 as (SELECT orderkey FROM ORDERS WHERE orderkey < 100)" + + "SELECT * FROM cte2)" + + "SELECT * FROM cte1 JOIN cte1 ON true", + anyTree( + sequence( + cteProducer(addQueryScopeDelimiter("cte1", 1), anyTree(cteConsumer(addQueryScopeDelimiter("cte2", 0)))), + cteProducer(addQueryScopeDelimiter("cte2", 0), anyTree(tableScan("orders"))), + anyTree(join(anyTree(cteConsumer(addQueryScopeDelimiter("cte1", 1))), anyTree(cteConsumer(addQueryScopeDelimiter("cte1", 1)))))))); } @Test @@ -96,10 +183,10 @@ public void testRedefinedCtesInDifferentScope() "SELECT * FROM cte2 JOIN cte1 ON true", anyTree( sequence( - cteProducer("cte2", anyTree(tableScan("customer"))), - cteProducer("cte1", anyTree(cteConsumer("cte1" + delimiter + "cte2"))), - cteProducer("cte1" + delimiter + "cte2", anyTree(tableScan("orders"))), - anyTree(join(anyTree(cteConsumer("cte2")), anyTree(cteConsumer("cte1"))))))); + cteProducer(addQueryScopeDelimiter("cte2", 0), anyTree(tableScan("customer"))), + cteProducer(addQueryScopeDelimiter("cte1", 2), anyTree(cteConsumer(addQueryScopeDelimiter("cte2", 1)))), + cteProducer(addQueryScopeDelimiter("cte2", 1), anyTree(tableScan("orders"))), + anyTree(join(anyTree(cteConsumer(addQueryScopeDelimiter("cte2", 0))), anyTree(cteConsumer(addQueryScopeDelimiter("cte1", 2)))))))); } @Test @@ -109,9 +196,9 @@ public void testNestedCte() " temp2 as (SELECT * FROM temp1) " + "SELECT * FROM temp2", anyTree( - sequence(cteProducer("temp2", anyTree(cteConsumer("temp1"))), - cteProducer("temp1", anyTree(tableScan("orders"))), - anyTree(cteConsumer("temp2"))))); + sequence(cteProducer(addQueryScopeDelimiter("temp2", 1), anyTree(cteConsumer(addQueryScopeDelimiter("temp1", 0)))), + cteProducer(addQueryScopeDelimiter("temp1", 0), anyTree(tableScan("orders"))), + anyTree(cteConsumer(addQueryScopeDelimiter("temp2", 1)))))); } @Test @@ -121,9 +208,9 @@ public void testMultipleIndependentCtes() " temp2 as (SELECT custkey FROM CUSTOMER) " + "SELECT * FROM temp1, temp2", anyTree( - sequence(cteProducer("temp1", anyTree(tableScan("orders"))), - cteProducer("temp2", anyTree(tableScan("customer"))), - anyTree(join(anyTree(cteConsumer("temp1")), anyTree(cteConsumer("temp2"))))))); + sequence(cteProducer(addQueryScopeDelimiter("temp1", 0), anyTree(tableScan("orders"))), + cteProducer(addQueryScopeDelimiter("temp2", 1), anyTree(tableScan("customer"))), + anyTree(join(anyTree(cteConsumer(addQueryScopeDelimiter("temp1", 0))), anyTree(cteConsumer(addQueryScopeDelimiter("temp2", 1)))))))); } @Test @@ -133,9 +220,9 @@ public void testDependentCtes() " temp2 as (SELECT orderkey FROM temp1) " + "SELECT * FROM temp2 , temp1", anyTree( - sequence(cteProducer("temp2", anyTree(cteConsumer("temp1"))), - cteProducer("temp1", anyTree(tableScan("orders"))), - anyTree(join(anyTree(cteConsumer("temp2")), anyTree(cteConsumer("temp1"))))))); + sequence(cteProducer(addQueryScopeDelimiter("temp2", 1), anyTree(cteConsumer(addQueryScopeDelimiter("temp1", 0)))), + cteProducer(addQueryScopeDelimiter("temp1", 0), anyTree(tableScan("orders"))), + anyTree(join(anyTree(cteConsumer(addQueryScopeDelimiter("temp2", 1))), anyTree(cteConsumer(addQueryScopeDelimiter("temp1", 0)))))))); } @Test @@ -147,15 +234,15 @@ public void testComplexCteWithJoins() "SELECT li.orderkey, s.suppkey, s.name FROM cte_line_item li JOIN SUPPLIER s ON li.suppkey = s.suppkey", anyTree( sequence( - cteProducer("cte_line_item", + cteProducer(addQueryScopeDelimiter("cte_line_item", 1), anyTree( join( anyTree(tableScan("lineitem")), - anyTree(cteConsumer("cte_orders"))))), - cteProducer("cte_orders", anyTree(tableScan("orders"))), + anyTree(cteConsumer(addQueryScopeDelimiter("cte_orders", 0)))))), + cteProducer(addQueryScopeDelimiter("cte_orders", 0), anyTree(tableScan("orders"))), anyTree( join( - anyTree(cteConsumer("cte_line_item")), + anyTree(cteConsumer(addQueryScopeDelimiter("cte_line_item", 1))), anyTree(tableScan("supplier"))))))); } @@ -180,8 +267,8 @@ public void testSimplePersistentCteWithRowTypeAndNonRowType() ") SELECT * FROM temp", anyTree( sequence( - cteProducer("temp", anyTree(values("status", "amount"))), - anyTree(cteConsumer("temp"))))); + cteProducer(addQueryScopeDelimiter("temp", 0), anyTree(values("status", "amount"))), + anyTree(cteConsumer(addQueryScopeDelimiter("temp", 0)))))); } @Test @@ -195,6 +282,11 @@ public void testNoPersistentCteWithZeroLengthVarcharType() anyTree(values("text_column", "number_column"))); } + private String addQueryScopeDelimiter(String cteName, int scope) + { + return String.valueOf(scope) + delimiter + cteName; + } + private void assertUnitPlan(String sql, PlanMatchPattern pattern) { List optimizers = ImmutableList.of(