Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -97,6 +98,7 @@ public class Analysis
private final Map<NodeRef<Node>, List<Expression>> outputExpressions = new LinkedHashMap<>();
private final Map<NodeRef<QuerySpecification>, List<FunctionCall>> windowFunctions = new LinkedHashMap<>();
private final Map<NodeRef<OrderBy>, List<FunctionCall>> orderByWindowFunctions = new LinkedHashMap<>();
private final Map<NodeRef<Node>, OptionalLong> limit = new LinkedHashMap<>();

private final Map<NodeRef<Join>, Expression> joins = new LinkedHashMap<>();
private final Map<NodeRef<Join>, JoinUsingAnalysis> joinUsing = new LinkedHashMap<>();
Expand Down Expand Up @@ -316,6 +318,22 @@ public List<Expression> getOrderByExpressions(Node node)
return orderByExpressions.get(NodeRef.of(node));
}

public void setLimit(Node node, OptionalLong rowCount)
{
limit.put(NodeRef.of(node), rowCount);
}

public void setLimit(Node node, long rowCount)
{
limit.put(NodeRef.of(node), OptionalLong.of(rowCount));
}

public OptionalLong getLimit(Node node)
{
checkState(limit.containsKey(NodeRef.of(node)), "missing LIMIT value for node %s", node);
return limit.get(NodeRef.of(node));
}

public void setOutputExpressions(Node node, List<Expression> expressions)
{
outputExpressions.put(NodeRef.of(node), ImmutableList.copyOf(expressions));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,7 @@ public enum SemanticErrorCode
MISSING_ROLE,

TOO_MANY_GROUPING_SETS,

INVALID_FETCH_FIRST_ROW_COUNT,
INVALID_LIMIT_ROW_COUNT,
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.FetchFirst;
import io.prestosql.sql.tree.FieldReference;
import io.prestosql.sql.tree.FrameBound;
import io.prestosql.sql.tree.FunctionCall;
Expand All @@ -94,6 +95,7 @@
import io.prestosql.sql.tree.JoinOn;
import io.prestosql.sql.tree.JoinUsing;
import io.prestosql.sql.tree.Lateral;
import io.prestosql.sql.tree.Limit;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.NaturalJoin;
import io.prestosql.sql.tree.Node;
Expand Down Expand Up @@ -141,6 +143,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -176,6 +179,8 @@
import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME;
import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_PROPERTY;
import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_FETCH_FIRST_ROW_COUNT;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LIMIT_ROW_COUNT;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_WINDOW_FRAME;
Expand Down Expand Up @@ -735,6 +740,10 @@ protected Scope visitQuery(Query node, Optional<Scope> scope)
analysis.setOrderByExpressions(node, emptyList());
}

if (node.getLimit().isPresent()) {
analyzeLimit(node.getLimit().get());
}

// Input fields == Output fields
analysis.setOutputExpressions(node, descriptorToFields(queryBodyScope));

Expand Down Expand Up @@ -1028,6 +1037,10 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional<Scope>
analysis.setOrderByExpressions(node, emptyList());
}

if (node.getLimit().isPresent()) {
analyzeLimit(node.getLimit().get());
}

List<Expression> sourceExpressions = new ArrayList<>(outputExpressions);
node.getHaving().ifPresent(sourceExpressions::add);

Expand Down Expand Up @@ -2113,6 +2126,56 @@ private List<Expression> analyzeOrderBy(Node node, List<SortItem> sortItems, Sco
return orderByFields;
}

private void analyzeLimit(Node node)
{
checkState(
node instanceof FetchFirst || node instanceof Limit,
"Invalid limit node type. Expected: FetchFirst or Limit. Actual: %s", node.getClass().getName());
if (node instanceof FetchFirst) {
analyzeLimit((FetchFirst) node);
}
else {
analyzeLimit((Limit) node);
}
}

private void analyzeLimit(FetchFirst node)
{
if (!node.getRowCount().isPresent()) {
analysis.setLimit(node, 1);
}
else {
long rowCount;
try {
rowCount = Long.parseLong(node.getRowCount().get());
}
catch (NumberFormatException e) {
throw new SemanticException(INVALID_FETCH_FIRST_ROW_COUNT, node, "Invalid FETCH FIRST row count: %s", node.getRowCount().get());
}
if (rowCount <= 0) {
throw new SemanticException(INVALID_FETCH_FIRST_ROW_COUNT, node, "FETCH FIRST row count must be positive (actual value: %s)", rowCount);
}
analysis.setLimit(node, rowCount);
}
}

private void analyzeLimit(Limit node)
{
if (node.getLimit().equalsIgnoreCase("all")) {
analysis.setLimit(node, OptionalLong.empty());
}
else {
long rowCount;
try {
rowCount = Long.parseLong(node.getLimit());
}
catch (NumberFormatException e) {
throw new SemanticException(INVALID_LIMIT_ROW_COUNT, node, "Invalid LIMIT row count: %s", node.getLimit());
}
analysis.setLimit(node, rowCount);
}
}

private Scope createAndAssignScope(Node node, Optional<Scope> parentScope)
{
return createAndAssignScope(node, parentScope, emptyList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ private PlanBuilder sort(PlanBuilder subPlan, QuerySpecification node)
return sort(subPlan, node.getOrderBy(), node.getLimit(), analysis.getOrderByExpressions(node));
}

private PlanBuilder sort(PlanBuilder subPlan, Optional<OrderBy> orderBy, Optional<String> limit, List<Expression> orderByExpressions)
private PlanBuilder sort(PlanBuilder subPlan, Optional<OrderBy> orderBy, Optional<Node> limit, List<Expression> orderByExpressions)
{
if (!orderBy.isPresent()) {
return subPlan;
Expand All @@ -887,8 +887,8 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional<OrderBy> orderBy, Optiona

PlanNode planNode;
OrderingScheme orderingScheme = new OrderingScheme(orderBySymbols.build(), orderings);
if (limit.isPresent() && !limit.get().equalsIgnoreCase("all")) {
planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), Long.parseLong(limit.get()), orderingScheme, TopNNode.Step.SINGLE);
if (limit.isPresent() && analysis.getLimit(limit.get()).isPresent()) {
planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), analysis.getLimit(limit.get()).getAsLong(), orderingScheme, TopNNode.Step.SINGLE);
}
else {
planNode = new SortNode(idAllocator.getNextId(), subPlan.getRoot(), orderingScheme);
Expand All @@ -907,13 +907,10 @@ private PlanBuilder limit(PlanBuilder subPlan, QuerySpecification node)
return limit(subPlan, node.getOrderBy(), node.getLimit());
}

private PlanBuilder limit(PlanBuilder subPlan, Optional<OrderBy> orderBy, Optional<String> limit)
private PlanBuilder limit(PlanBuilder subPlan, Optional<OrderBy> orderBy, Optional<Node> limit)
{
if (!orderBy.isPresent() && limit.isPresent()) {
if (!limit.get().equalsIgnoreCase("all")) {
long limitValue = Long.parseLong(limit.get());
subPlan = subPlan.withNewRoot(new LimitNode(idAllocator.getNextId(), subPlan.getRoot(), limitValue, false));
}
if (!orderBy.isPresent() && limit.isPresent() && analysis.getLimit(limit.get()).isPresent()) {
return subPlan.withNewRoot(new LimitNode(idAllocator.getNextId(), subPlan.getRoot(), analysis.getLimit(limit.get()).getAsLong(), false));
}

return subPlan;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.prestosql.sql.tree.AstVisitor;
import io.prestosql.sql.tree.DescribeInput;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.Limit;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.Node;
import io.prestosql.sql.tree.NullLiteral;
Expand Down Expand Up @@ -113,10 +114,10 @@ protected Node visitDescribeInput(DescribeInput node, Void context)

// return the positions and types of all parameters
Row[] rows = parameters.stream().map(parameter -> createDescribeInputRow(parameter, analysis)).toArray(Row[]::new);
Optional<String> limit = Optional.empty();
Optional<Node> limit = Optional.empty();
if (rows.length == 0) {
rows = new Row[] {row(new NullLiteral(), new NullLiteral())};
limit = Optional.of("0");
limit = Optional.of(new Limit("0"));
}

return simpleQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.DescribeOutput;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.Limit;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.Node;
import io.prestosql.sql.tree.NullLiteral;
Expand Down Expand Up @@ -103,12 +104,12 @@ protected Node visitDescribeOutput(DescribeOutput node, Void context)
Analyzer analyzer = new Analyzer(session, metadata, parser, accessControl, queryExplainer, parameters, warningCollector);
Analysis analysis = analyzer.analyze(statement, true);

Optional<String> limit = Optional.empty();
Optional<Node> limit = Optional.empty();
Row[] rows = analysis.getRootScope().getRelationType().getVisibleFields().stream().map(field -> createDescribeOutputRow(field, analysis)).toArray(Row[]::new);
if (rows.length == 0) {
NullLiteral nullLiteral = new NullLiteral();
rows = new Row[] {row(nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral)};
limit = Optional.of("0");
limit = Optional.of(new Limit("0"));
}
return simpleQuery(
selectList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME;
import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_PROPERTY;
import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_FETCH_FIRST_ROW_COUNT;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LIMIT_ROW_COUNT;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LITERAL;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE;
Expand Down Expand Up @@ -329,6 +331,19 @@ public void testOrderByNonComparable()
assertFails(TYPE_MISMATCH, "SELECT x FROM (SELECT approx_set(1) x) ORDER BY x");
}

@Test
public void testFetchFirstInvalidRowCount()
{
assertFails(INVALID_FETCH_FIRST_ROW_COUNT, "SELECT * FROM t1 FETCH FIRST 987654321098765432109876543210 ROWS ONLY");
assertFails(INVALID_FETCH_FIRST_ROW_COUNT, "SELECT * FROM t1 FETCH FIRST 0 ROWS ONLY");
}

@Test
public void testLimitInvalidRowCount()
{
assertFails(INVALID_LIMIT_ROW_COUNT, "SELECT * FROM t1 LIMIT 987654321098765432109876543210");
}

@Test
public void testNestedAggregation()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.functionCall;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.limit;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.markDistinct;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.node;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output;
Expand Down Expand Up @@ -833,4 +834,30 @@ public void testRemoveAggregationInSemiJoin()
"SELECT custkey FROM orders WHERE custkey IN (SELECT distinct custkey FROM customer)",
AggregationNode.class);
}

@Test
public void testOrderByFetch()
{
assertPlan(
"SELECT * FROM nation ORDER BY name FETCH FIRST 2 ROWS ONLY",
anyTree(
topN(
2,
ImmutableList.of(sort("NAME", ASCENDING, LAST)),
TopNNode.Step.PARTIAL,
tableScan("nation", ImmutableMap.of(
"NAME", "name")))));
}

@Test
public void testFetch()
{
assertPlan(
"SELECT * FROM nation FETCH FIRST 2 ROWS ONLY",
anyTree(
limit(
2,
any(
tableScan("nation")))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ property
queryNoWith:
queryTerm
(ORDER BY sortItem (',' sortItem)*)?
(LIMIT limit=(INTEGER_VALUE | ALL))?
((LIMIT limit=(INTEGER_VALUE | ALL)) | (FETCH (FIRST | NEXT) (fetchFirst=INTEGER_VALUE)? (ROW | ROWS) ONLY))?
;

queryTerm
Expand Down Expand Up @@ -492,14 +492,14 @@ nonReserved
| CALL | CASCADE | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CURRENT
| DATA | DATE | DAY | DEFINER | DESC | DISTRIBUTED
| EXCLUDING | EXPLAIN
| FILTER | FIRST | FOLLOWING | FORMAT | FUNCTIONS
| FETCH | FILTER | FIRST | FOLLOWING | FORMAT | FUNCTIONS
| GRANT | GRANTED | GRANTS | GRAPHVIZ
| HOUR
| IF | INCLUDING | INPUT | INTERVAL | INVOKER | IO | ISOLATION
| JSON
| LAST | LATERAL | LEVEL | LIMIT | LOGICAL
| MAP | MINUTE | MONTH
| NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS
| NEXT | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS
| ONLY | OPTION | ORDINALITY | OUTPUT | OVER
| PARTITION | PARTITIONS | PATH | POSITION | PRECEDING | PRIVILEGES | PROPERTIES
| RANGE | READ | RENAME | REPEATABLE | REPLACE | RESET | RESTRICT | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS
Expand Down Expand Up @@ -569,6 +569,7 @@ EXISTS: 'EXISTS';
EXPLAIN: 'EXPLAIN';
EXTRACT: 'EXTRACT';
FALSE: 'FALSE';
FETCH: 'FETCH';
FILTER: 'FILTER';
FIRST: 'FIRST';
FOLLOWING: 'FOLLOWING';
Expand Down Expand Up @@ -613,6 +614,7 @@ MAP: 'MAP';
MINUTE: 'MINUTE';
MONTH: 'MONTH';
NATURAL: 'NATURAL';
NEXT: 'NEXT';
NFC : 'NFC';
NFD : 'NFD';
NFKC : 'NFKC';
Expand Down
3 changes: 2 additions & 1 deletion presto-parser/src/main/java/io/prestosql/sql/QueryUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.prestosql.sql.tree.GroupBy;
import io.prestosql.sql.tree.Identifier;
import io.prestosql.sql.tree.LogicalBinaryExpression;
import io.prestosql.sql.tree.Node;
import io.prestosql.sql.tree.OrderBy;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.Query;
Expand Down Expand Up @@ -199,7 +200,7 @@ public static Query simpleQuery(Select select, Relation from, Optional<Expressio
return simpleQuery(select, from, where, Optional.empty(), Optional.empty(), orderBy, Optional.empty());
}

public static Query simpleQuery(Select select, Relation from, Optional<Expression> where, Optional<GroupBy> groupBy, Optional<Expression> having, Optional<OrderBy> orderBy, Optional<String> limit)
public static Query simpleQuery(Select select, Relation from, Optional<Expression> where, Optional<GroupBy> groupBy, Optional<Expression> having, Optional<OrderBy> orderBy, Optional<Node> limit)
{
return query(new QuerySpecification(
select,
Expand Down
Loading