Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ groupBy

groupingElement
: groupingSet #singleGroupingSet
| AUTO #auto
| ROLLUP '(' (groupingSet (',' groupingSet)*)? ')' #rollup
| CUBE '(' (groupingSet (',' groupingSet)*)? ')' #cube
| GROUPING SETS '(' groupingSet (',' groupingSet)* ')' #multipleGroupingSets
Expand Down Expand Up @@ -1043,6 +1044,7 @@ AS: 'AS';
ASC: 'ASC';
AT: 'AT';
AUTHORIZATION: 'AUTHORIZATION';
AUTO: 'AUTO';
BEGIN: 'BEGIN';
BERNOULLI: 'BERNOULLI';
BETWEEN: 'BETWEEN';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public void test()
"ASC",
"AT",
"AUTHORIZATION",
"AUTO",
"BEGIN",
"BERNOULLI",
"BETWEEN",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
import io.trino.sql.tree.AllRows;
import io.trino.sql.tree.Analyze;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.AutoGroupBy;
import io.trino.sql.tree.Call;
import io.trino.sql.tree.CallArgument;
import io.trino.sql.tree.ColumnDefinition;
Expand Down Expand Up @@ -4417,7 +4418,7 @@ private void checkGroupingSetsCount(GroupBy node)
for (GroupingElement element : node.getGroupingElements()) {
try {
int product;
if (element instanceof SimpleGroupBy) {
if (element instanceof SimpleGroupBy || element instanceof AutoGroupBy) {
product = 1;
}
else if (element instanceof GroupingSets groupingSets) {
Expand Down Expand Up @@ -4489,6 +4490,32 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope,
groupingExpressions.add(column);
}
}
else if (groupingElement instanceof AutoGroupBy) {
// Analyze non-aggregation outputs
for (Expression column : outputExpressions) {
if (column instanceof FunctionCall functionCall) {
ResolvedFunction function = getResolvedFunction(functionCall);
if (function.functionKind() == AGGREGATE) {
continue;
}
}
else {
verifyNoAggregateWindowOrGroupingFunctions(session, functionResolver, accessControl, column, "GROUP BY clause");
analyzeExpression(column, scope);
}

ResolvedField field = analysis.getColumnReferenceFields().get(NodeRef.of(column));
if (field != null) {
sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId())));
}
else {
analysis.recordSubqueries(node, analyzeExpression(column, scope));
complexExpressions.add(column);
}

groupingExpressions.add(column);
}
}
else if (groupingElement instanceof GroupingSets element) {
for (Expression column : groupingElement.getExpressions()) {
analyzeExpression(column, scope);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3174,6 +3174,10 @@ public void testGroupBy()
{
// TODO: validate output
analyze("SELECT a, SUM(b) FROM t1 GROUP BY a");
analyze("SELECT a, SUM(b) FROM t1 GROUP BY AUTO");
analyze("SELECT a as x, SUM(b) FROM t1 GROUP BY AUTO");
analyze("SELECT a, SUM(b) FROM t1 GROUP BY ALL AUTO");
analyze("SELECT a as x, SUM(b) FROM t1 GROUP BY DISTINCT AUTO");
}

@Test
Expand Down
208 changes: 208 additions & 0 deletions core/trino-main/src/test/java/io/trino/sql/query/TestGroupBy.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,212 @@ public void testGroupByRepeatedOrdinals()
"SELECT null GROUP BY 1, 1"))
.matches("VALUES null");
}

@Test
void testGroupByAuto()
{
assertThat(assertions.query(
"""
SELECT *
FROM (VALUES 1) t(a)
GROUP BY AUTO
"""))
.matches("VALUES 1");

assertThat(assertions.query(
"""
SELECT *
FROM (VALUES 1, 2) t(a)
GROUP BY AUTO
"""))
.matches("VALUES 1, 2");

assertThat(assertions.query(
"""
SELECT sum(a)
FROM (VALUES (1), (2)) t(a)
GROUP BY AUTO
"""))
.matches("VALUES BIGINT '3'");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a AS new_a, sum(b) AS sum_b
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a + 1, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO
"""))
.matches("VALUES (2, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT abs(a), sum(b)
FROM (VALUES (-1, 10), (-1, 20)) t(a, b)
GROUP BY AUTO
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT sum(b), a
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO
"""))
.matches("VALUES (BIGINT '30', 1)");

assertThat(assertions.query(
"""
SELECT sum(a)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO
"""))
.matches("VALUES (BIGINT '2')");

// ALL AUTO
assertThat(assertions.query(
"""
SELECT sum(a)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY ALL AUTO
"""))
.matches("VALUES (BIGINT '2')");

// DISTINCT AUTO
assertThat(assertions.query(
"""
SELECT sum(a)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY DISTINCT AUTO
"""))
.matches("VALUES (BIGINT '2')");

// ROLLUP
assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO, ROLLUP(b)
"""))
.matches("VALUES (1, BIGINT '10'), (1, BIGINT '20'), (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY ALL AUTO, ROLLUP(b)
"""))
.matches("VALUES (1, BIGINT '10'), (1, BIGINT '20'), (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY DISTINCT AUTO, ROLLUP(a)
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a, b, c, sum(b)
FROM (VALUES (1, 1, 1, 1), (1, 1, 1, 2), (2, 2, 2, 3)) t(a, b, c, d)
GROUP BY AUTO, ROLLUP(a)
"""))
.matches(
"""
SELECT a, b, c, sum(b)
FROM (VALUES (1, 1, 1, 1), (1, 1, 1, 2), (2, 2, 2, 3)) t(a, b, c, d)
GROUP BY (a, b, c), ROLLUP(a)
""");

// CUBE
assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO, CUBE(b)
"""))
.matches("VALUES (1, BIGINT '10'), (1, BIGINT '20'), (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY ALL AUTO, CUBE(b)
"""))
.matches("VALUES (1, BIGINT '10'), (1, BIGINT '20'), (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY DISTINCT AUTO, CUBE(a)
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a, b, c, sum(b)
FROM (VALUES (1, 1, 1, 1), (1, 1, 1, 2), (2, 2, 2, 3)) t(a, b, c, d)
GROUP BY AUTO, CUBE(a)
"""))
.matches(
"""
SELECT a, b, c, sum(b)
FROM (VALUES (1, 1, 1, 1), (1, 1, 1, 2), (2, 2, 2, 3)) t(a, b, c, d)
GROUP BY (a, b, c), CUBE(a)
""");

// GROUPING SETS
assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY AUTO, GROUPING SETS((b))
"""))
.matches("VALUES (1, BIGINT '10'), (1, BIGINT '20')");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY ALL AUTO, GROUPING SETS((b))
"""))
.matches("VALUES (1, BIGINT '10'), (1, BIGINT '20')");

assertThat(assertions.query(
"""
SELECT a, sum(b)
FROM (VALUES (1, 10), (1, 20)) t(a, b)
GROUP BY DISTINCT AUTO, GROUPING SETS((a))
"""))
.matches("VALUES (1, BIGINT '30')");

assertThat(assertions.query(
"""
SELECT a, b, c, sum(b)
FROM (VALUES (1, 1, 1, 1), (1, 1, 1, 2), (2, 2, 2, 3)) t(a, b, c, d)
GROUP BY AUTO, GROUPING SETS((a))
"""))
.matches(
"""
SELECT a, b, c, sum(b)
FROM (VALUES (1, 1, 1, 1), (1, 1, 1, 2), (2, 2, 2, 3)) t(a, b, c, d)
GROUP BY (a, b, c), GROUPING SETS((a))
""");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.sql.tree.Array;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.AtTimeZone;
import io.trino.sql.tree.AutoGroupBy;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.BinaryLiteral;
import io.trino.sql.tree.BooleanLiteral;
Expand Down Expand Up @@ -1131,6 +1132,9 @@ static String formatGroupBy(List<GroupingElement> groupingElements)
result = formatGroupingSet(columns);
}
}
else if (groupingElement instanceof AutoGroupBy) {
result = "AUTO";
}
else if (groupingElement instanceof GroupingSets groupingSets) {
String type = switch (groupingSets.getType()) {
case EXPLICIT -> "GROUPING SETS";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.trino.sql.tree.Array;
import io.trino.sql.tree.AssignmentStatement;
import io.trino.sql.tree.AtTimeZone;
import io.trino.sql.tree.AutoGroupBy;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.BinaryLiteral;
import io.trino.sql.tree.BooleanLiteral;
Expand Down Expand Up @@ -1308,6 +1309,12 @@ public Node visitSingleGroupingSet(SqlBaseParser.SingleGroupingSetContext contex
return new SimpleGroupBy(getLocation(context), visit(context.groupingSet().expression(), Expression.class));
}

@Override
public Node visitAuto(SqlBaseParser.AutoContext context)
{
return new AutoGroupBy(getLocation(context));
}

@Override
public Node visitRollup(SqlBaseParser.RollupContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,11 @@ protected R visitSimpleGroupBy(SimpleGroupBy node, C context)
return visitGroupingElement(node, context);
}

protected R visitAutoGroupBy(AutoGroupBy node, C context)
{
return visitGroupingElement(node, context);
}

protected R visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, C context)
{
return visitExpression(node, context);
Expand Down
Loading
Loading