Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -1473,15 +1473,15 @@ public static class GroupingSetAnalysis
{
private final List<Expression> originalExpressions;

private final List<Set<FieldId>> cubes;
private final List<List<FieldId>> rollups;
private final List<List<Set<FieldId>>> cubes;
private final List<List<Set<FieldId>>> rollups;
private final List<List<Set<FieldId>>> ordinarySets;
private final List<Expression> complexExpressions;

public GroupingSetAnalysis(
List<Expression> originalExpressions,
List<Set<FieldId>> cubes,
List<List<FieldId>> rollups,
List<List<Set<FieldId>>> cubes,
List<List<Set<FieldId>>> rollups,
List<List<Set<FieldId>>> ordinarySets,
List<Expression> complexExpressions)
{
Expand All @@ -1497,12 +1497,12 @@ public List<Expression> getOriginalExpressions()
return originalExpressions;
}

public List<Set<FieldId>> getCubes()
public List<List<Set<FieldId>>> getCubes()
{
return cubes;
}

public List<List<FieldId>> getRollups()
public List<List<Set<FieldId>>> getRollups()
{
return rollups;
}
Expand All @@ -1520,8 +1520,12 @@ public List<Expression> getComplexExpressions()
public Set<FieldId> getAllFields()
{
return Streams.concat(
cubes.stream().flatMap(Collection::stream),
rollups.stream().flatMap(Collection::stream),
cubes.stream()
.flatMap(Collection::stream)
.flatMap(Collection::stream),
rollups.stream()
.flatMap(Collection::stream)
.flatMap(Collection::stream),
ordinarySets.stream()
.flatMap(Collection::stream)
.flatMap(Collection::stream))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@
import io.trino.sql.tree.CreateTable;
import io.trino.sql.tree.CreateTableAsSelect;
import io.trino.sql.tree.CreateView;
import io.trino.sql.tree.Cube;
import io.trino.sql.tree.Deallocate;
import io.trino.sql.tree.Delete;
import io.trino.sql.tree.Deny;
Expand Down Expand Up @@ -203,7 +202,6 @@
import io.trino.sql.tree.ResetSession;
import io.trino.sql.tree.Revoke;
import io.trino.sql.tree.Rollback;
import io.trino.sql.tree.Rollup;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.RowPattern;
import io.trino.sql.tree.SampledRelation;
Expand Down Expand Up @@ -3975,18 +3973,18 @@ private void checkGroupingSetsCount(GroupBy node)
if (element instanceof SimpleGroupBy) {
product = 1;
}
else if (element instanceof Cube) {
int exponent = element.getExpressions().size();
if (exponent > 30) {
throw new ArithmeticException();
}
product = 1 << exponent;
}
else if (element instanceof Rollup) {
product = element.getExpressions().size() + 1;
}
else if (element instanceof GroupingSets) {
product = ((GroupingSets) element).getSets().size();
else if (element instanceof GroupingSets groupingSets) {
product = switch (groupingSets.getType()) {
case CUBE -> {
int exponent = ((GroupingSets) element).getSets().size();
if (exponent > 30) {
throw new ArithmeticException();
Comment thread
martint marked this conversation as resolved.
Outdated
}
yield 1 << exponent;
}
case ROLLUP -> groupingSets.getSets().size() + 1;
case EXPLICIT -> groupingSets.getSets().size();
};
}
else {
throw new UnsupportedOperationException("Unsupported grouping element type: " + element.getClass().getName());
Expand All @@ -4007,8 +4005,8 @@ else if (element instanceof GroupingSets) {
private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope, List<Expression> outputExpressions)
{
if (node.getGroupBy().isPresent()) {
ImmutableList.Builder<Set<FieldId>> cubes = ImmutableList.builder();
ImmutableList.Builder<List<FieldId>> rollups = ImmutableList.builder();
ImmutableList.Builder<List<Set<FieldId>>> cubes = ImmutableList.builder();
ImmutableList.Builder<List<Set<FieldId>>> rollups = ImmutableList.builder();
ImmutableList.Builder<List<Set<FieldId>>> sets = ImmutableList.builder();
ImmutableList.Builder<Expression> complexExpressions = ImmutableList.builder();
ImmutableList.Builder<Expression> groupingExpressions = ImmutableList.builder();
Expand Down Expand Up @@ -4044,7 +4042,7 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope,
groupingExpressions.add(column);
}
}
else {
else if (groupingElement instanceof GroupingSets element) {
for (Expression column : groupingElement.getExpressions()) {
analyzeExpression(column, scope);
if (!analysis.getColumnReferences().contains(NodeRef.of(column))) {
Expand All @@ -4054,34 +4052,18 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope,
groupingExpressions.add(column);
}

if (groupingElement instanceof Cube) {
Set<FieldId> cube = groupingElement.getExpressions().stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableSet());

cubes.add(cube);
}
else if (groupingElement instanceof Rollup) {
List<FieldId> rollup = groupingElement.getExpressions().stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableList());

rollups.add(rollup);
}
else if (groupingElement instanceof GroupingSets) {
List<Set<FieldId>> groupingSets = ((GroupingSets) groupingElement).getSets().stream()
.map(set -> set.stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableSet()))
.collect(toImmutableList());

sets.add(groupingSets);
List<Set<FieldId>> groupingSets = element.getSets().stream()
.map(set -> set.stream()
.map(NodeRef::of)
.map(analysis.getColumnReferenceFields()::get)
.map(ResolvedField::getFieldId)
.collect(toImmutableSet()))
.collect(toImmutableList());

switch (element.getType()) {
case CUBE -> cubes.add(groupingSets);
case ROLLUP -> rollups.add(groupingSets);
case EXPLICIT -> sets.add(groupingSets);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
Expand Down Expand Up @@ -1274,13 +1275,21 @@ private static List<Set<FieldId>> enumerateGroupingSets(GroupingSetAnalysis grou
{
List<List<Set<FieldId>>> partialSets = new ArrayList<>();

for (Set<FieldId> cube : groupingSetAnalysis.getCubes()) {
partialSets.add(ImmutableList.copyOf(Sets.powerSet(cube)));
for (List<Set<FieldId>> cube : groupingSetAnalysis.getCubes()) {
List<Set<FieldId>> sets = Sets.powerSet(ImmutableSet.copyOf(cube)).stream()
.map(set -> set.stream()
.flatMap(Collection::stream)
.collect(toImmutableSet()))
.collect(toImmutableList());

partialSets.add(sets);
}

for (List<FieldId> rollup : groupingSetAnalysis.getRollups()) {
for (List<Set<FieldId>> rollup : groupingSetAnalysis.getRollups()) {
List<Set<FieldId>> sets = IntStream.rangeClosed(0, rollup.size())
.mapToObj(i -> ImmutableSet.copyOf(rollup.subList(0, i)))
.mapToObj(prefixLength -> rollup.subList(0, prefixLength).stream()
.flatMap(Collection::stream)
.collect(Collectors.toSet()))
.collect(toImmutableList());

partialSets.add(sets);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,40 @@ public void testRollupAggregationWithOrderedLimit()
"ORDER BY a LIMIT 2"))
.matches("VALUES 1, 2");
}

@Test
public void testComplexCube()
{
assertThat(assertions.query("""
SELECT a, b, c, count(*)
FROM (VALUES (1, 1, 1), (1, 2, 2), (1, 2, 2)) t(a, b, c)
GROUP BY CUBE (a, (b, c))
"""))
.matches("""
VALUES
( 1, 1, 1, BIGINT '1'),
( 1, 2, 2, 2),
( 1, NULL, NULL, 3),
(NULL, NULL, NULL, 3),
(NULL, 1, 1, 1),
(NULL, 2, 2, 2)
""");
}

@Test
public void testComplexRollup()
{
assertThat(assertions.query("""
SELECT a, b, c, count(*)
FROM (VALUES (1, 1, 1), (1, 2, 2), (1, 2, 2)) t(a, b, c)
GROUP BY ROLLUP (a, (b, c))
"""))
.matches("""
VALUES
( 1, 1, 1, BIGINT '1'),
(NULL, NULL, NULL, 3),
( 1, NULL, NULL, 3),
( 1, 2, 2, 2)
""");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ groupBy

groupingElement
: groupingSet #singleGroupingSet
| ROLLUP '(' (expression (',' expression)*)? ')' #rollup
| CUBE '(' (expression (',' expression)*)? ')' #cube
| ROLLUP '(' (groupingSet (',' groupingSet)*)? ')' #rollup
| CUBE '(' (groupingSet (',' groupingSet)*)? ')' #cube
| GROUPING SETS '(' groupingSet (',' groupingSet)* ')' #multipleGroupingSets
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.trino.sql.tree.CharLiteral;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Cube;
import io.trino.sql.tree.CurrentCatalog;
import io.trino.sql.tree.CurrentPath;
import io.trino.sql.tree.CurrentSchema;
Expand Down Expand Up @@ -84,7 +83,6 @@
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.QuantifiedComparisonExpression;
import io.trino.sql.tree.Rollup;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.RowDataType;
import io.trino.sql.tree.SearchedCaseExpression;
Expand Down Expand Up @@ -1230,15 +1228,24 @@ static String formatGroupBy(List<GroupingElement> groupingElements)
}
}
else if (groupingElement instanceof GroupingSets) {
String type;
switch (((GroupingSets) groupingElement).getType()) {
case EXPLICIT:
type = "GROUPING SETS";
break;
case CUBE:
type = "CUBE";
break;
case ROLLUP:
type = "ROLLUP";
break;
default:
throw new UnsupportedOperationException();
}

result = ((GroupingSets) groupingElement).getSets().stream()
.map(ExpressionFormatter::formatGroupingSet)
.collect(joining(", ", "GROUPING SETS (", ")"));
}
else if (groupingElement instanceof Cube) {
result = "CUBE " + formatGroupingSet(groupingElement.getExpressions());
}
else if (groupingElement instanceof Rollup) {
result = "ROLLUP " + formatGroupingSet(groupingElement.getExpressions());
.collect(joining(", ", type + " (", ")"));
}
return result;
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import io.trino.sql.tree.CreateTable;
import io.trino.sql.tree.CreateTableAsSelect;
import io.trino.sql.tree.CreateView;
import io.trino.sql.tree.Cube;
import io.trino.sql.tree.CurrentCatalog;
import io.trino.sql.tree.CurrentPath;
import io.trino.sql.tree.CurrentSchema;
Expand Down Expand Up @@ -187,7 +186,6 @@
import io.trino.sql.tree.Revoke;
import io.trino.sql.tree.RevokeRoles;
import io.trino.sql.tree.Rollback;
import io.trino.sql.tree.Rollup;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.RowDataType;
import io.trino.sql.tree.RowPattern;
Expand Down Expand Up @@ -278,6 +276,9 @@
import static io.trino.sql.parser.SqlBaseParser.TIMESTAMP;
import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_END;
import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_START;
import static io.trino.sql.tree.GroupingSets.Type.CUBE;
import static io.trino.sql.tree.GroupingSets.Type.EXPLICIT;
import static io.trino.sql.tree.GroupingSets.Type.ROLLUP;
import static io.trino.sql.tree.JsonExists.ErrorBehavior.ERROR;
import static io.trino.sql.tree.JsonExists.ErrorBehavior.FALSE;
import static io.trino.sql.tree.JsonExists.ErrorBehavior.TRUE;
Expand Down Expand Up @@ -1145,19 +1146,23 @@ public Node visitSingleGroupingSet(SqlBaseParser.SingleGroupingSetContext contex
@Override
public Node visitRollup(SqlBaseParser.RollupContext context)
{
return new Rollup(getLocation(context), visit(context.expression(), Expression.class));
return new GroupingSets(getLocation(context), ROLLUP, context.groupingSet().stream()
.map(groupingSet -> visit(groupingSet.expression(), Expression.class))
.collect(toList()));
}

@Override
public Node visitCube(SqlBaseParser.CubeContext context)
{
return new Cube(getLocation(context), visit(context.expression(), Expression.class));
return new GroupingSets(getLocation(context), CUBE, context.groupingSet().stream()
.map(groupingSet -> visit(groupingSet.expression(), Expression.class))
.collect(toList()));
}

@Override
public Node visitMultipleGroupingSets(SqlBaseParser.MultipleGroupingSetsContext context)
{
return new GroupingSets(getLocation(context), context.groupingSet().stream()
return new GroupingSets(getLocation(context), EXPLICIT, context.groupingSet().stream()
.map(groupingSet -> visit(groupingSet.expression(), Expression.class))
.collect(toList()));
}
Expand Down
10 changes: 0 additions & 10 deletions core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -862,21 +862,11 @@ protected R visitGroupingElement(GroupingElement node, C context)
return visitNode(node, context);
}

protected R visitCube(Cube node, C context)
{
return visitGroupingElement(node, context);
}

protected R visitGroupingSets(GroupingSets node, C context)
{
return visitGroupingElement(node, context);
}

protected R visitRollup(Rollup node, C context)
{
return visitGroupingElement(node, context);
}

protected R visitSimpleGroupBy(SimpleGroupBy node, C context)
{
return visitGroupingElement(node, context);
Expand Down
Loading