diff --git a/core/trino-main/src/test/java/io/trino/execution/TestParameterExtractor.java b/core/trino-main/src/test/java/io/trino/execution/TestParameterExtractor.java index 70cdc9711625..3986bb01de1a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestParameterExtractor.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestParameterExtractor.java @@ -65,4 +65,16 @@ public void testLambda() assertThat(ParameterExtractor.getParameterCount(statement)).isEqualTo(1); } + + @Test + public void testWith() + { + // The parameter from CTE has id=0. The parameter from the query has id=1. In the DESCRIBE statement they will be listed following this order. + Statement statement = sqlParser.createStatement("WITH t(a) AS (VALUES ?) SELECT a + ? FROM t", new ParsingOptions()); + assertThat(ParameterExtractor.getParameters(statement)) + .containsExactly( + new Parameter(new NodeLocation(1, 22), 0), + new Parameter(new NodeLocation(1, 38), 1)); + assertThat(ParameterExtractor.getParameterCount(statement)).isEqualTo(2); + } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java index 19fe8278990d..09318ec2a388 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java +++ b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java @@ -45,6 +45,7 @@ import io.trino.sql.tree.Values; import io.trino.sql.tree.WhenClause; import io.trino.sql.tree.WindowDefinition; +import io.trino.sql.tree.With; import java.util.List; import java.util.Optional; @@ -296,4 +297,23 @@ public static Query query(QueryBody body) Optional.empty(), Optional.empty()); } + + public static Query query(With with, Select select, Relation from) + { + return new Query( + Optional.of(with), + new QuerySpecification( + select, + Optional.of(from), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 0723b8fc4af7..5cb99de52557 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -913,11 +913,12 @@ public Node visitProperty(SqlBaseParser.PropertyContext context) @Override public Node visitQuery(SqlBaseParser.QueryContext context) { + Optional with = visitIfPresent(context.with(), With.class); Query body = (Query) visit(context.queryNoWith()); return new Query( getLocation(context), - visitIfPresent(context.with(), With.class), + with, body.getQueryBody(), body.getOrderBy(), body.getOffset(), diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 00b8b0210eb1..1d3be32d8bfd 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -253,6 +253,10 @@ import static io.trino.sql.parser.TreeNodes.rowType; import static io.trino.sql.parser.TreeNodes.simpleType; import static io.trino.sql.testing.TreeAssertions.assertFormattedSql; +import static io.trino.sql.tree.ArithmeticBinaryExpression.Operator.ADD; +import static io.trino.sql.tree.ArithmeticBinaryExpression.Operator.DIVIDE; +import static io.trino.sql.tree.ArithmeticBinaryExpression.Operator.MULTIPLY; +import static io.trino.sql.tree.ArithmeticBinaryExpression.Operator.SUBTRACT; import static io.trino.sql.tree.ArithmeticUnaryExpression.negative; import static io.trino.sql.tree.ArithmeticUnaryExpression.positive; import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; @@ -896,25 +900,32 @@ public void testPrecedenceAndAssociativity() new NotExpression(new LongLiteral("1")), new LongLiteral("2"))); - assertExpression("-1 + 2", new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.ADD, + assertExpression("-1 + 2", new ArithmeticBinaryExpression( + ADD, new LongLiteral("-1"), new LongLiteral("2"))); - assertExpression("1 - 2 - 3", new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.SUBTRACT, - new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.SUBTRACT, + assertExpression("1 - 2 - 3", new ArithmeticBinaryExpression( + SUBTRACT, + new ArithmeticBinaryExpression( + SUBTRACT, new LongLiteral("1"), new LongLiteral("2")), new LongLiteral("3"))); - assertExpression("1 / 2 / 3", new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, - new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, + assertExpression("1 / 2 / 3", new ArithmeticBinaryExpression( + DIVIDE, + new ArithmeticBinaryExpression( + DIVIDE, new LongLiteral("1"), new LongLiteral("2")), new LongLiteral("3"))); - assertExpression("1 + 2 * 3", new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.ADD, + assertExpression("1 + 2 * 3", new ArithmeticBinaryExpression( + ADD, new LongLiteral("1"), - new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, + new ArithmeticBinaryExpression( + MULTIPLY, new LongLiteral("2"), new LongLiteral("3")))); } @@ -1874,7 +1885,7 @@ public void testMerge() Optional.of(equal(nameReference("c", "action"), new StringLiteral("mod"))), ImmutableList.of( new MergeUpdate.Assignment(new Identifier("qty"), new ArithmeticBinaryExpression( - ArithmeticBinaryExpression.Operator.ADD, + ADD, nameReference("qty"), nameReference("c", "qty"))), new MergeUpdate.Assignment(new Identifier("ts"), new CurrentTime(CurrentTime.Function.TIMESTAMP)))), @@ -2759,6 +2770,20 @@ public void testPrepareWithParameters() Optional.empty(), Optional.of(new Offset(new Parameter(1))), Optional.of(new FetchFirst(new Parameter(2), true))))); + + // The parameter from CTE has id=0. The parameter from the query has id=1. In the DESCRIBE statement they will be listed following this order. + assertStatement("PREPARE myquery FROM WITH t(a) AS (VALUES ROW(?)) SELECT a + ? FROM t", + new Prepare( + identifier("myquery"), + query( + new With( + false, + ImmutableList.of(new WithQuery( + identifier("t"), + query(values(row(new Parameter(0)))), + Optional.of(ImmutableList.of(identifier("a")))))), + selectList(new ArithmeticBinaryExpression(ADD, identifier("a"), new Parameter(1))), + table(QualifiedName.of("t"))))); } @Test diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index 7479f11540c7..2765dff8ccdb 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -1364,6 +1364,17 @@ public void testDescribeInput() .row(4, "decimal(3,2)") .build(); assertEqualsIgnoreOrder(actual, expected); + + session = Session.builder(getSession()) + .addPreparedStatement("my_query", "WITH t(a) AS (VALUES lower(?)) SELECT NOT ? FROM t") + .build(); + actual = computeActual(session, "DESCRIBE INPUT my_query"); + // The parameter from CTE is listed first with id=0. The parameter from the query is listed next with id=1. + expected = resultBuilder(session, BIGINT, VARCHAR) + .row(0, "varchar(0)") + .row(1, "boolean") + .build(); + assertEquals(actual, expected); } @Test