diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java index aaaf1c9232b6..deec6a8bd6dd 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java @@ -58,6 +58,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.OptionalLong; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -97,6 +98,7 @@ public class Analysis private final Map, List> outputExpressions = new LinkedHashMap<>(); private final Map, List> windowFunctions = new LinkedHashMap<>(); private final Map, List> orderByWindowFunctions = new LinkedHashMap<>(); + private final Map, OptionalLong> limit = new LinkedHashMap<>(); private final Map, Expression> joins = new LinkedHashMap<>(); private final Map, JoinUsingAnalysis> joinUsing = new LinkedHashMap<>(); @@ -316,6 +318,22 @@ public List getOrderByExpressions(Node node) return orderByExpressions.get(NodeRef.of(node)); } + public void setLimit(Node node, OptionalLong rowCount) + { + limit.put(NodeRef.of(node), rowCount); + } + + public void setLimit(Node node, long rowCount) + { + limit.put(NodeRef.of(node), OptionalLong.of(rowCount)); + } + + public OptionalLong getLimit(Node node) + { + checkState(limit.containsKey(NodeRef.of(node)), "missing LIMIT value for node %s", node); + return limit.get(NodeRef.of(node)); + } + public void setOutputExpressions(Node node, List expressions) { outputExpressions.put(NodeRef.of(node), ImmutableList.copyOf(expressions)); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java index 19308f965133..658c1aedc90f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java @@ -99,4 +99,7 @@ public enum SemanticErrorCode MISSING_ROLE, TOO_MANY_GROUPING_SETS, + + INVALID_FETCH_FIRST_ROW_COUNT, + INVALID_LIMIT_ROW_COUNT, } diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java index 9e2fd5fb985d..aac1e9079417 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java @@ -78,6 +78,7 @@ import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; import io.prestosql.sql.tree.ExpressionTreeRewriter; +import io.prestosql.sql.tree.FetchFirst; import io.prestosql.sql.tree.FieldReference; import io.prestosql.sql.tree.FrameBound; import io.prestosql.sql.tree.FunctionCall; @@ -94,6 +95,7 @@ import io.prestosql.sql.tree.JoinOn; import io.prestosql.sql.tree.JoinUsing; import io.prestosql.sql.tree.Lateral; +import io.prestosql.sql.tree.Limit; import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.NaturalJoin; import io.prestosql.sql.tree.Node; @@ -141,6 +143,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalLong; import java.util.Set; import java.util.stream.Collectors; @@ -176,6 +179,8 @@ import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME; import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_PROPERTY; import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; +import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_FETCH_FIRST_ROW_COUNT; +import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LIMIT_ROW_COUNT; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_WINDOW_FRAME; @@ -735,6 +740,10 @@ protected Scope visitQuery(Query node, Optional scope) analysis.setOrderByExpressions(node, emptyList()); } + if (node.getLimit().isPresent()) { + analyzeLimit(node.getLimit().get()); + } + // Input fields == Output fields analysis.setOutputExpressions(node, descriptorToFields(queryBodyScope)); @@ -1028,6 +1037,10 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional analysis.setOrderByExpressions(node, emptyList()); } + if (node.getLimit().isPresent()) { + analyzeLimit(node.getLimit().get()); + } + List sourceExpressions = new ArrayList<>(outputExpressions); node.getHaving().ifPresent(sourceExpressions::add); @@ -2113,6 +2126,56 @@ private List analyzeOrderBy(Node node, List sortItems, Sco return orderByFields; } + private void analyzeLimit(Node node) + { + checkState( + node instanceof FetchFirst || node instanceof Limit, + "Invalid limit node type. Expected: FetchFirst or Limit. Actual: %s", node.getClass().getName()); + if (node instanceof FetchFirst) { + analyzeLimit((FetchFirst) node); + } + else { + analyzeLimit((Limit) node); + } + } + + private void analyzeLimit(FetchFirst node) + { + if (!node.getRowCount().isPresent()) { + analysis.setLimit(node, 1); + } + else { + long rowCount; + try { + rowCount = Long.parseLong(node.getRowCount().get()); + } + catch (NumberFormatException e) { + throw new SemanticException(INVALID_FETCH_FIRST_ROW_COUNT, node, "Invalid FETCH FIRST row count: %s", node.getRowCount().get()); + } + if (rowCount <= 0) { + throw new SemanticException(INVALID_FETCH_FIRST_ROW_COUNT, node, "FETCH FIRST row count must be positive (actual value: %s)", rowCount); + } + analysis.setLimit(node, rowCount); + } + } + + private void analyzeLimit(Limit node) + { + if (node.getLimit().equalsIgnoreCase("all")) { + analysis.setLimit(node, OptionalLong.empty()); + } + else { + long rowCount; + try { + rowCount = Long.parseLong(node.getLimit()); + } + catch (NumberFormatException e) { + throw new SemanticException(INVALID_LIMIT_ROW_COUNT, node, "Invalid LIMIT row count: %s", node.getLimit()); + } + analysis.setLimit(node, rowCount); + } + } + private Scope createAndAssignScope(Node node, Optional parentScope) { return createAndAssignScope(node, parentScope, emptyList()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java index 8efe82ed5bbc..f884eaf0eb24 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java @@ -865,7 +865,7 @@ private PlanBuilder sort(PlanBuilder subPlan, QuerySpecification node) return sort(subPlan, node.getOrderBy(), node.getLimit(), analysis.getOrderByExpressions(node)); } - private PlanBuilder sort(PlanBuilder subPlan, Optional orderBy, Optional limit, List orderByExpressions) + private PlanBuilder sort(PlanBuilder subPlan, Optional orderBy, Optional limit, List orderByExpressions) { if (!orderBy.isPresent()) { return subPlan; @@ -887,8 +887,8 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional orderBy, Optiona PlanNode planNode; OrderingScheme orderingScheme = new OrderingScheme(orderBySymbols.build(), orderings); - if (limit.isPresent() && !limit.get().equalsIgnoreCase("all")) { - planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), Long.parseLong(limit.get()), orderingScheme, TopNNode.Step.SINGLE); + if (limit.isPresent() && analysis.getLimit(limit.get()).isPresent()) { + planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), analysis.getLimit(limit.get()).getAsLong(), orderingScheme, TopNNode.Step.SINGLE); } else { planNode = new SortNode(idAllocator.getNextId(), subPlan.getRoot(), orderingScheme); @@ -907,13 +907,10 @@ private PlanBuilder limit(PlanBuilder subPlan, QuerySpecification node) return limit(subPlan, node.getOrderBy(), node.getLimit()); } - private PlanBuilder limit(PlanBuilder subPlan, Optional orderBy, Optional limit) + private PlanBuilder limit(PlanBuilder subPlan, Optional orderBy, Optional limit) { - if (!orderBy.isPresent() && limit.isPresent()) { - if (!limit.get().equalsIgnoreCase("all")) { - long limitValue = Long.parseLong(limit.get()); - subPlan = subPlan.withNewRoot(new LimitNode(idAllocator.getNextId(), subPlan.getRoot(), limitValue, false)); - } + if (!orderBy.isPresent() && limit.isPresent() && analysis.getLimit(limit.get()).isPresent()) { + return subPlan.withNewRoot(new LimitNode(idAllocator.getNextId(), subPlan.getRoot(), analysis.getLimit(limit.get()).getAsLong(), false)); } return subPlan; diff --git a/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeInputRewrite.java b/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeInputRewrite.java index ccbf87bdb631..b90a6df8049b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeInputRewrite.java +++ b/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeInputRewrite.java @@ -27,6 +27,7 @@ import io.prestosql.sql.tree.AstVisitor; import io.prestosql.sql.tree.DescribeInput; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.Limit; import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NullLiteral; @@ -113,10 +114,10 @@ protected Node visitDescribeInput(DescribeInput node, Void context) // return the positions and types of all parameters Row[] rows = parameters.stream().map(parameter -> createDescribeInputRow(parameter, analysis)).toArray(Row[]::new); - Optional limit = Optional.empty(); + Optional limit = Optional.empty(); if (rows.length == 0) { rows = new Row[] {row(new NullLiteral(), new NullLiteral())}; - limit = Optional.of("0"); + limit = Optional.of(new Limit("0")); } return simpleQuery( diff --git a/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeOutputRewrite.java b/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeOutputRewrite.java index fa3db982dd42..f58a0cb704df 100644 --- a/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeOutputRewrite.java +++ b/presto-main/src/main/java/io/prestosql/sql/rewrite/DescribeOutputRewrite.java @@ -29,6 +29,7 @@ import io.prestosql.sql.tree.BooleanLiteral; import io.prestosql.sql.tree.DescribeOutput; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.Limit; import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NullLiteral; @@ -103,12 +104,12 @@ protected Node visitDescribeOutput(DescribeOutput node, Void context) Analyzer analyzer = new Analyzer(session, metadata, parser, accessControl, queryExplainer, parameters, warningCollector); Analysis analysis = analyzer.analyze(statement, true); - Optional limit = Optional.empty(); + Optional limit = Optional.empty(); Row[] rows = analysis.getRootScope().getRelationType().getVisibleFields().stream().map(field -> createDescribeOutputRow(field, analysis)).toArray(Row[]::new); if (rows.length == 0) { NullLiteral nullLiteral = new NullLiteral(); rows = new Row[] {row(nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral)}; - limit = Optional.of("0"); + limit = Optional.of(new Limit("0")); } return simpleQuery( selectList( diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java index 60e0bf789266..8500e0ffd19e 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java @@ -83,6 +83,8 @@ import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME; import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_PROPERTY; import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; +import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_FETCH_FIRST_ROW_COUNT; +import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LIMIT_ROW_COUNT; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; @@ -329,6 +331,19 @@ public void testOrderByNonComparable() assertFails(TYPE_MISMATCH, "SELECT x FROM (SELECT approx_set(1) x) ORDER BY x"); } + @Test + public void testFetchFirstInvalidRowCount() + { + assertFails(INVALID_FETCH_FIRST_ROW_COUNT, "SELECT * FROM t1 FETCH FIRST 987654321098765432109876543210 ROWS ONLY"); + assertFails(INVALID_FETCH_FIRST_ROW_COUNT, "SELECT * FROM t1 FETCH FIRST 0 ROWS ONLY"); + } + + @Test + public void testLimitInvalidRowCount() + { + assertFails(INVALID_LIMIT_ROW_COUNT, "SELECT * FROM t1 LIMIT 987654321098765432109876543210"); + } + @Test public void testNestedAggregation() { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java index cd844bf314e9..a1ff18e28936 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java @@ -69,6 +69,7 @@ import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.functionCall; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.limit; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.markDistinct; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.node; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; @@ -833,4 +834,30 @@ public void testRemoveAggregationInSemiJoin() "SELECT custkey FROM orders WHERE custkey IN (SELECT distinct custkey FROM customer)", AggregationNode.class); } + + @Test + public void testOrderByFetch() + { + assertPlan( + "SELECT * FROM nation ORDER BY name FETCH FIRST 2 ROWS ONLY", + anyTree( + topN( + 2, + ImmutableList.of(sort("NAME", ASCENDING, LAST)), + TopNNode.Step.PARTIAL, + tableScan("nation", ImmutableMap.of( + "NAME", "name"))))); + } + + @Test + public void testFetch() + { + assertPlan( + "SELECT * FROM nation FETCH FIRST 2 ROWS ONLY", + anyTree( + limit( + 2, + any( + tableScan("nation"))))); + } } diff --git a/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 index 4e11c4c39f07..785e79877959 100644 --- a/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 @@ -149,7 +149,7 @@ property queryNoWith: queryTerm (ORDER BY sortItem (',' sortItem)*)? - (LIMIT limit=(INTEGER_VALUE | ALL))? + ((LIMIT limit=(INTEGER_VALUE | ALL)) | (FETCH (FIRST | NEXT) (fetchFirst=INTEGER_VALUE)? (ROW | ROWS) ONLY))? ; queryTerm @@ -492,14 +492,14 @@ nonReserved | CALL | CASCADE | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CURRENT | DATA | DATE | DAY | DEFINER | DESC | DISTRIBUTED | EXCLUDING | EXPLAIN - | FILTER | FIRST | FOLLOWING | FORMAT | FUNCTIONS + | FETCH | FILTER | FIRST | FOLLOWING | FORMAT | FUNCTIONS | GRANT | GRANTED | GRANTS | GRAPHVIZ | HOUR | IF | INCLUDING | INPUT | INTERVAL | INVOKER | IO | ISOLATION | JSON | LAST | LATERAL | LEVEL | LIMIT | LOGICAL | MAP | MINUTE | MONTH - | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS + | NEXT | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS | ONLY | OPTION | ORDINALITY | OUTPUT | OVER | PARTITION | PARTITIONS | PATH | POSITION | PRECEDING | PRIVILEGES | PROPERTIES | RANGE | READ | RENAME | REPEATABLE | REPLACE | RESET | RESTRICT | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS @@ -569,6 +569,7 @@ EXISTS: 'EXISTS'; EXPLAIN: 'EXPLAIN'; EXTRACT: 'EXTRACT'; FALSE: 'FALSE'; +FETCH: 'FETCH'; FILTER: 'FILTER'; FIRST: 'FIRST'; FOLLOWING: 'FOLLOWING'; @@ -613,6 +614,7 @@ MAP: 'MAP'; MINUTE: 'MINUTE'; MONTH: 'MONTH'; NATURAL: 'NATURAL'; +NEXT: 'NEXT'; NFC : 'NFC'; NFD : 'NFD'; NFKC : 'NFKC'; diff --git a/presto-parser/src/main/java/io/prestosql/sql/QueryUtil.java b/presto-parser/src/main/java/io/prestosql/sql/QueryUtil.java index 50e15f361d9b..b29250bdf6c0 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/QueryUtil.java +++ b/presto-parser/src/main/java/io/prestosql/sql/QueryUtil.java @@ -23,6 +23,7 @@ import io.prestosql.sql.tree.GroupBy; import io.prestosql.sql.tree.Identifier; import io.prestosql.sql.tree.LogicalBinaryExpression; +import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.OrderBy; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.Query; @@ -199,7 +200,7 @@ public static Query simpleQuery(Select select, Relation from, Optional where, Optional groupBy, Optional having, Optional orderBy, Optional limit) + public static Query simpleQuery(Select select, Relation from, Optional where, Optional groupBy, Optional having, Optional orderBy, Optional limit) { return query(new QuerySpecification( select, diff --git a/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java b/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java index 794fe6a987b0..337bf3624070 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java @@ -46,6 +46,7 @@ import io.prestosql.sql.tree.ExplainOption; import io.prestosql.sql.tree.ExplainType; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.FetchFirst; import io.prestosql.sql.tree.Grant; import io.prestosql.sql.tree.GrantRoles; import io.prestosql.sql.tree.GrantorSpecification; @@ -59,6 +60,7 @@ import io.prestosql.sql.tree.JoinUsing; import io.prestosql.sql.tree.Lateral; import io.prestosql.sql.tree.LikeClause; +import io.prestosql.sql.tree.Limit; import io.prestosql.sql.tree.NaturalJoin; import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.OrderBy; @@ -262,10 +264,8 @@ protected Void visitQuery(Query node, Integer indent) } if (node.getLimit().isPresent()) { - append(indent, "LIMIT " + node.getLimit().get()) - .append('\n'); + process(node.getLimit().get(), indent); } - return null; } @@ -302,8 +302,7 @@ protected Void visitQuerySpecification(QuerySpecification node, Integer indent) } if (node.getLimit().isPresent()) { - append(indent, "LIMIT " + node.getLimit().get()) - .append('\n'); + process(node.getLimit().get(), indent); } return null; } @@ -316,6 +315,22 @@ protected Void visitOrderBy(OrderBy node, Integer indent) return null; } + @Override + protected Void visitFetchFirst(FetchFirst node, Integer indent) + { + append(indent, "FETCH FIRST " + node.getRowCount().map(c -> c + " ROWS ONLY").orElse("ROW ONLY")) + .append('\n'); + return null; + } + + @Override + protected Void visitLimit(Limit node, Integer indent) + { + append(indent, "LIMIT " + node.getLimit()) + .append('\n'); + return null; + } + @Override protected Void visitSelect(Select node, Integer indent) { diff --git a/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java b/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java index d90a2b57d17e..fbfe6a5b46d1 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java @@ -68,6 +68,7 @@ import io.prestosql.sql.tree.ExplainType; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.Extract; +import io.prestosql.sql.tree.FetchFirst; import io.prestosql.sql.tree.FrameBound; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.GenericLiteral; @@ -97,6 +98,7 @@ import io.prestosql.sql.tree.Lateral; import io.prestosql.sql.tree.LikeClause; import io.prestosql.sql.tree.LikePredicate; +import io.prestosql.sql.tree.Limit; import io.prestosql.sql.tree.LogicalBinaryExpression; import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.NaturalJoin; @@ -580,10 +582,19 @@ public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context) orderBy = Optional.of(new OrderBy(getLocation(context.ORDER()), visit(context.sortItem(), SortItem.class))); } + Optional limit = Optional.empty(); + if (context.FETCH() != null) { + limit = Optional.of(new FetchFirst(getTextIfPresent(context.fetchFirst))); + } + else if (context.LIMIT() != null) { + limit = Optional.of(new Limit(getTextIfPresent(context.limit).orElseThrow(() -> new IllegalStateException("Missing LIMIT value")))); + } + if (term instanceof QuerySpecification) { // When we have a simple query specification - // followed by order by limit, fold the order by and limit - // clauses into the query specification (analyzer/planner + // followed by order by limit or fetch, + // fold the order by and limit or fetch clauses + // into the query specification (analyzer/planner // expects this structure to resolve references with respect // to columns defined in the query specification) QuerySpecification query = (QuerySpecification) term; @@ -599,7 +610,7 @@ public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context) query.getGroupBy(), query.getHaving(), orderBy, - getTextIfPresent(context.limit)), + limit), Optional.empty(), Optional.empty()); } @@ -609,7 +620,7 @@ public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context) Optional.empty(), term, orderBy, - getTextIfPresent(context.limit)); + limit); } @Override diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/AstVisitor.java b/presto-parser/src/main/java/io/prestosql/sql/tree/AstVisitor.java index 877b5c8b964e..60493df557f0 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/AstVisitor.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/AstVisitor.java @@ -222,6 +222,16 @@ protected R visitOrderBy(OrderBy node, C context) return visitNode(node, context); } + protected R visitFetchFirst(FetchFirst node, C context) + { + return visitNode(node, context); + } + + protected R visitLimit(Limit node, C context) + { + return visitNode(node, context); + } + protected R visitQuerySpecification(QuerySpecification node, C context) { return visitQueryBody(node, context); diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/FetchFirst.java b/presto-parser/src/main/java/io/prestosql/sql/tree/FetchFirst.java new file mode 100644 index 000000000000..b91c8494f26a --- /dev/null +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/FetchFirst.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public class FetchFirst + extends Node +{ + private final Optional rowCount; + + public FetchFirst(String rowCount) + { + this(Optional.empty(), Optional.of(rowCount)); + } + + public FetchFirst(Optional rowCount) + { + this(Optional.empty(), rowCount); + } + + public FetchFirst(NodeLocation location, Optional rowCount) + { + this(Optional.of(location), rowCount); + } + + public FetchFirst(Optional location, Optional rowCount) + { + super(location); + this.rowCount = rowCount; + } + + public Optional getRowCount() + { + return rowCount; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitFetchFirst(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + FetchFirst o = (FetchFirst) obj; + return Objects.equals(rowCount, o.rowCount); + } + + @Override + public int hashCode() + { + return Objects.hash(rowCount); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("rowCount", rowCount.orElse(null)) + .omitNullValues() + .toString(); + } +} diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/Limit.java b/presto-parser/src/main/java/io/prestosql/sql/tree/Limit.java new file mode 100644 index 000000000000..698e0f473389 --- /dev/null +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/Limit.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public class Limit + extends Node +{ + private final String limit; + + public Limit(String limit) + { + this(Optional.empty(), limit); + } + + public Limit(NodeLocation location, String limit) + { + this(Optional.of(location), limit); + } + + public Limit(Optional location, String limit) + { + super(location); + this.limit = limit; + } + + public String getLimit() + { + return limit; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitLimit(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + Limit o = (Limit) obj; + return Objects.equals(limit, o.limit); + } + + @Override + public int hashCode() + { + return Objects.hash(limit); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("limit", limit) + .toString(); + } +} diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/Query.java b/presto-parser/src/main/java/io/prestosql/sql/tree/Query.java index 9f01e2e28c96..dd157e78727b 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/Query.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/Query.java @@ -20,6 +20,7 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public class Query @@ -28,13 +29,13 @@ public class Query private final Optional with; private final QueryBody queryBody; private final Optional orderBy; - private final Optional limit; + private final Optional limit; public Query( Optional with, QueryBody queryBody, Optional orderBy, - Optional limit) + Optional limit) { this(Optional.empty(), with, queryBody, orderBy, limit); } @@ -44,7 +45,7 @@ public Query( Optional with, QueryBody queryBody, Optional orderBy, - Optional limit) + Optional limit) { this(Optional.of(location), with, queryBody, orderBy, limit); } @@ -54,13 +55,14 @@ private Query( Optional with, QueryBody queryBody, Optional orderBy, - Optional limit) + Optional limit) { super(location); requireNonNull(with, "with is null"); requireNonNull(queryBody, "queryBody is null"); requireNonNull(orderBy, "orderBy is null"); requireNonNull(limit, "limit is null"); + checkArgument(!limit.isPresent() || limit.get() instanceof FetchFirst || limit.get() instanceof Limit, "limit must be optional of either FetchFirst or Limit type"); this.with = with; this.queryBody = queryBody; @@ -83,7 +85,7 @@ public Optional getOrderBy() return orderBy; } - public Optional getLimit() + public Optional getLimit() { return limit; } @@ -101,6 +103,7 @@ public List getChildren() with.ifPresent(nodes::add); nodes.add(queryBody); orderBy.ifPresent(nodes::add); + limit.ifPresent(nodes::add); return nodes.build(); } diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/QuerySpecification.java b/presto-parser/src/main/java/io/prestosql/sql/tree/QuerySpecification.java index ab395ee6e902..c7793a2e7a67 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/QuerySpecification.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/QuerySpecification.java @@ -20,6 +20,7 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public class QuerySpecification @@ -31,7 +32,7 @@ public class QuerySpecification private final Optional groupBy; private final Optional having; private final Optional orderBy; - private final Optional limit; + private final Optional limit; public QuerySpecification( Select select, @@ -40,7 +41,7 @@ public QuerySpecification( Optional groupBy, Optional having, Optional orderBy, - Optional limit) + Optional limit) { this(Optional.empty(), select, from, where, groupBy, having, orderBy, limit); } @@ -53,7 +54,7 @@ public QuerySpecification( Optional groupBy, Optional having, Optional orderBy, - Optional limit) + Optional limit) { this(Optional.of(location), select, from, where, groupBy, having, orderBy, limit); } @@ -66,7 +67,7 @@ private QuerySpecification( Optional groupBy, Optional having, Optional orderBy, - Optional limit) + Optional limit) { super(location); requireNonNull(select, "select is null"); @@ -76,6 +77,11 @@ private QuerySpecification( requireNonNull(having, "having is null"); requireNonNull(orderBy, "orderBy is null"); requireNonNull(limit, "limit is null"); + checkArgument( + !limit.isPresent() + || limit.get() instanceof FetchFirst + || limit.get() instanceof Limit, + "limit must be optional of either FetchFirst or Limit type"); this.select = select; this.from = from; @@ -116,7 +122,7 @@ public Optional getOrderBy() return orderBy; } - public Optional getLimit() + public Optional getLimit() { return limit; } @@ -137,6 +143,7 @@ public List getChildren() groupBy.ifPresent(nodes::add); having.ifPresent(nodes::add); orderBy.ifPresent(nodes::add); + limit.ifPresent(nodes::add); return nodes.build(); } diff --git a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java index 67dafa64edd4..70f8db1c9920 100644 --- a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java @@ -60,6 +60,7 @@ import io.prestosql.sql.tree.ExplainFormat; import io.prestosql.sql.tree.ExplainType; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.FetchFirst; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.GenericLiteral; import io.prestosql.sql.tree.Grant; @@ -82,6 +83,7 @@ import io.prestosql.sql.tree.LambdaExpression; import io.prestosql.sql.tree.Lateral; import io.prestosql.sql.tree.LikeClause; +import io.prestosql.sql.tree.Limit; import io.prestosql.sql.tree.LogicalBinaryExpression; import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.NaturalJoin; @@ -531,8 +533,36 @@ public void testBetween() } @Test - public void testLimitAll() + public void testSelectWithLimit() { + assertStatement("SELECT * FROM table1 LIMIT 2", + new Query( + Optional.empty(), + new QuerySpecification( + selectList(new AllColumns()), + Optional.of(new Table(QualifiedName.of("table1"))), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new Limit("2"))), + Optional.empty(), + Optional.empty())); + + assertStatement("SELECT * FROM table1 LIMIT ALL", + new Query( + Optional.empty(), + new QuerySpecification( + selectList(new AllColumns()), + Optional.of(new Table(QualifiedName.of("table1"))), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new Limit("ALL"))), + Optional.empty(), + Optional.empty())); + Query valuesQuery = query(values( row(new LongLiteral("1"), new StringLiteral("1")), row(new LongLiteral("2"), new StringLiteral("2")))); @@ -544,7 +574,7 @@ public void testLimitAll() Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of("ALL"))); + Optional.of(new Limit("ALL")))); } @Test @@ -872,6 +902,51 @@ public void testSelectWithOrderBy() Optional.empty())); } + @Test + public void testSelectWithFetch() + { + assertStatement("SELECT * FROM table1 FETCH FIRST 2 ROWS ONLY", + new Query( + Optional.empty(), + new QuerySpecification( + selectList(new AllColumns()), + Optional.of(new Table(QualifiedName.of("table1"))), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new FetchFirst("2"))), + Optional.empty(), + Optional.empty())); + + assertStatement("SELECT * FROM table1 FETCH NEXT ROW ONLY", + new Query( + Optional.empty(), + new QuerySpecification( + selectList(new AllColumns()), + Optional.of(new Table(QualifiedName.of("table1"))), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new FetchFirst(Optional.empty()))), + Optional.empty(), + Optional.empty())); + + Query valuesQuery = query(values( + row(new LongLiteral("1"), new StringLiteral("1")), + row(new LongLiteral("2"), new StringLiteral("2")))); + + assertStatement("SELECT * FROM (VALUES (1, '1'), (2, '2')) FETCH FIRST ROW ONLY", + simpleQuery(selectList(new AllColumns()), + subquery(valuesQuery), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new FetchFirst(Optional.empty())))); + } + @Test public void testSelectWithGroupBy() { diff --git a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParserErrorHandling.java b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParserErrorHandling.java index ef769b0b1400..fa27d3ee6162 100644 --- a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParserErrorHandling.java +++ b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParserErrorHandling.java @@ -49,7 +49,7 @@ public Object[][] getStatements() {"select * from 'oops", "line 1:15: mismatched input '''. Expecting: '(', 'LATERAL', 'UNNEST', "}, {"select *\nfrom x\nfrom", - "line 3:1: mismatched input 'from'. Expecting: ',', '.', 'AS', 'CROSS', 'EXCEPT', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', 'NATURAL', " + + "line 3:1: mismatched input 'from'. Expecting: ',', '.', 'AS', 'CROSS', 'EXCEPT', 'FETCH', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', 'NATURAL', " + "'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', , "}, {"select *\nfrom x\nwhere from", "line 3:7: mismatched input 'from'. Expecting: "}, @@ -114,7 +114,7 @@ public Object[][] getStatements() {"SELECT foo(*) filter (", "line 1:23: mismatched input ''. Expecting: 'WHERE'"}, {"SELECT * FROM t t x", - "line 1:19: mismatched input 'x'. Expecting: '(', ',', 'CROSS', 'EXCEPT', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', 'NATURAL', 'ORDER', " + + "line 1:19: mismatched input 'x'. Expecting: '(', ',', 'CROSS', 'EXCEPT', 'FETCH', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', 'LIMIT', 'NATURAL', 'ORDER', " + "'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', "}, {"SELECT * FROM t WHERE EXISTS (", "line 1:31: mismatched input ''. Expecting: "}, diff --git a/presto-verifier/src/main/java/io/prestosql/verifier/QueryRewriter.java b/presto-verifier/src/main/java/io/prestosql/verifier/QueryRewriter.java index 7b767dd644f7..211f7a98a05d 100644 --- a/presto-verifier/src/main/java/io/prestosql/verifier/QueryRewriter.java +++ b/presto-verifier/src/main/java/io/prestosql/verifier/QueryRewriter.java @@ -29,6 +29,7 @@ import io.prestosql.sql.tree.Identifier; import io.prestosql.sql.tree.Insert; import io.prestosql.sql.tree.LikeClause; +import io.prestosql.sql.tree.Limit; import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.QueryBody; @@ -203,12 +204,12 @@ private List getColumns(Connection connection, CreateTableAsSelect creat querySpecification.getGroupBy(), querySpecification.getHaving(), querySpecification.getOrderBy(), - Optional.of("0")); + Optional.of(new Limit("0"))); zeroRowsQuery = new io.prestosql.sql.tree.Query(createSelectClause.getWith(), innerQuery, Optional.empty(), Optional.empty()); } else { - zeroRowsQuery = new io.prestosql.sql.tree.Query(createSelectClause.getWith(), innerQuery, Optional.empty(), Optional.of("0")); + zeroRowsQuery = new io.prestosql.sql.tree.Query(createSelectClause.getWith(), innerQuery, Optional.empty(), Optional.of(new Limit("0"))); } ImmutableList.Builder columns = ImmutableList.builder();