diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index 953ec8e4c3e7..d7b94ecea62c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -206,7 +206,7 @@ public LogicalPlanner( this.metadata = plannerContext.getMetadata(); this.typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); - this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, metadata, session); + this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, plannerContext, session); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/StatisticsAggregationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/StatisticsAggregationPlanner.java index f3b34488b627..fc7ac561078f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/StatisticsAggregationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/StatisticsAggregationPlanner.java @@ -21,17 +21,20 @@ import io.trino.operator.aggregation.MaxDataSizeForStats; import io.trino.operator.aggregation.SumDataSizeForStats; import io.trino.spi.TrinoException; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; import io.trino.spi.statistics.ColumnStatisticMetadata; import io.trino.spi.statistics.ColumnStatisticType; import io.trino.spi.statistics.TableStatisticType; import io.trino.spi.statistics.TableStatisticsMetadata; import io.trino.spi.type.Type; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.StatisticAggregations; import io.trino.sql.planner.plan.StatisticAggregationsDescriptor; +import io.trino.sql.tree.Expression; import io.trino.sql.tree.QualifiedName; -import io.trino.sql.tree.SymbolReference; import java.util.List; import java.util.Map; @@ -42,6 +45,7 @@ import static com.google.common.base.Verify.verifyNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.MoreCollectors.onlyElement; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.statistics.TableStatisticType.ROW_COUNT; import static io.trino.spi.type.BigintType.BIGINT; @@ -52,13 +56,17 @@ public class StatisticsAggregationPlanner { private final SymbolAllocator symbolAllocator; + private final PlannerContext plannerContext; private final Metadata metadata; + private final LiteralEncoder literalEncoder; private final Session session; - public StatisticsAggregationPlanner(SymbolAllocator symbolAllocator, Metadata metadata, Session session) + public StatisticsAggregationPlanner(SymbolAllocator symbolAllocator, PlannerContext plannerContext, Session session) { this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); - this.metadata = requireNonNull(metadata, "metadata is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.metadata = plannerContext.getMetadata(); + this.literalEncoder = new LiteralEncoder(plannerContext); this.session = requireNonNull(session, "session is null"); } @@ -94,23 +102,22 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta for (ColumnStatisticMetadata columnStatisticMetadata : statisticsMetadata.getColumnStatistics()) { String columnName = columnStatisticMetadata.getColumnName(); + String connectorAggregationId = columnStatisticMetadata.getConnectorAggregationId(); Symbol inputSymbol = columnToSymbolMap.get(columnName); verifyNotNull(inputSymbol, "inputSymbol is null"); Type inputType = symbolAllocator.getTypes().get(inputSymbol); verifyNotNull(inputType, "inputType is null for symbol: %s", inputSymbol); ColumnStatisticsAggregation aggregation; - String symbolHint; if (columnStatisticMetadata.getStatisticTypeIfPresent().isPresent()) { ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType(); aggregation = createColumnAggregation(statisticType, inputSymbol, inputType); - symbolHint = statisticType + ":" + columnName; } else { FunctionName aggregationName = columnStatisticMetadata.getAggregation(); - aggregation = createColumnAggregation(aggregationName, inputSymbol, inputType); - symbolHint = aggregationName.getName() + ":" + columnName; + Optional projection = columnStatisticMetadata.getProjection(); + aggregation = createColumnAggregation(aggregationName, inputSymbol, inputType, projection); } - Symbol symbol = symbolAllocator.newSymbol(symbolHint, aggregation.getOutputType()); + Symbol symbol = symbolAllocator.newSymbol(connectorAggregationId + ":" + columnName, aggregation.getOutputType()); aggregations.put(symbol, aggregation.getAggregation()); descriptor.addColumnStatistic(columnStatisticMetadata, symbol); } @@ -122,34 +129,54 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType statisticType, Symbol input, Type inputType) { return switch (statisticType) { - case MIN_VALUE -> createAggregation(QualifiedName.of("min"), input.toSymbolReference(), inputType); - case MAX_VALUE -> createAggregation(QualifiedName.of("max"), input.toSymbolReference(), inputType); - case NUMBER_OF_DISTINCT_VALUES -> createAggregation(QualifiedName.of("approx_distinct"), input.toSymbolReference(), inputType); + case MIN_VALUE -> createAggregation(QualifiedName.of("min"), input, inputType); + case MAX_VALUE -> createAggregation(QualifiedName.of("max"), input, inputType); + case NUMBER_OF_DISTINCT_VALUES -> createAggregation(QualifiedName.of("approx_distinct"), input, inputType); case NUMBER_OF_DISTINCT_VALUES_SUMMARY -> // we use $approx_set here and not approx_set because latter is not defined for all types supported by Trino - createAggregation(QualifiedName.of("$approx_set"), input.toSymbolReference(), inputType); - case NUMBER_OF_NON_NULL_VALUES -> createAggregation(QualifiedName.of("count"), input.toSymbolReference(), inputType); - case NUMBER_OF_TRUE_VALUES -> createAggregation(QualifiedName.of("count_if"), input.toSymbolReference(), BOOLEAN); - case TOTAL_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(SumDataSizeForStats.NAME), input.toSymbolReference(), inputType); - case MAX_VALUE_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(MaxDataSizeForStats.NAME), input.toSymbolReference(), inputType); + createAggregation(QualifiedName.of("$approx_set"), input, inputType); + case NUMBER_OF_NON_NULL_VALUES -> createAggregation(QualifiedName.of("count"), input, inputType); + case NUMBER_OF_TRUE_VALUES -> createAggregation(QualifiedName.of("count_if"), input, BOOLEAN); + case TOTAL_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(SumDataSizeForStats.NAME), input, inputType); + case MAX_VALUE_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(MaxDataSizeForStats.NAME), input, inputType); }; } - private ColumnStatisticsAggregation createColumnAggregation(FunctionName aggregation, Symbol input, Type inputType) + private ColumnStatisticsAggregation createColumnAggregation(FunctionName aggregation, Symbol input, Type inputType, Optional projection) { checkArgument(aggregation.getCatalogSchema().isEmpty(), "Catalog/schema name not supported"); - return createAggregation(QualifiedName.of(aggregation.getName()), input.toSymbolReference(), inputType); + return createAggregation(QualifiedName.of(aggregation.getName()), input, inputType, projection); } - private ColumnStatisticsAggregation createAggregation(QualifiedName functionName, SymbolReference input, Type inputType) + private ColumnStatisticsAggregation createAggregation(QualifiedName functionName, Symbol input, Type inputType) { - ResolvedFunction resolvedFunction = metadata.resolveFunction(session, functionName, fromTypes(inputType)); + return createAggregation(functionName, input, inputType, Optional.empty()); + } + + private ColumnStatisticsAggregation createAggregation(QualifiedName functionName, Symbol input, Type inputType, Optional projection) + { + Expression aggregationInput; + Type aggregationInputType; + if (projection.isEmpty()) { + aggregationInput = input.toSymbolReference(); + aggregationInputType = inputType; + } + else { + Variable inputVariable = ConnectorExpressions.preOrder(projection.get()) + .filter(Variable.class::isInstance) + .map(Variable.class::cast) + .collect(onlyElement()); + verify(inputVariable.getType().equals(inputType), "Projection variable type %s does not match column type %s", inputVariable.getType(), inputType); + aggregationInput = ConnectorExpressionTranslator.translate(session, projection.get(), plannerContext, ImmutableMap.of(inputVariable.getName(), input), literalEncoder); + aggregationInputType = projection.get().getType(); + } + ResolvedFunction resolvedFunction = metadata.resolveFunction(session, functionName, fromTypes(aggregationInputType)); Type resolvedType = getOnlyElement(resolvedFunction.getSignature().getArgumentTypes()); - verify(resolvedType.equals(inputType), "resolved function input type does not match the input type: %s != %s", resolvedType, inputType); + verify(resolvedType.equals(aggregationInputType), "resolved function input type does not match the input type: %s != %s", resolvedType, aggregationInputType); return new ColumnStatisticsAggregation( new AggregationNode.Aggregation( resolvedFunction, - ImmutableList.of(input), + ImmutableList.of(aggregationInput), false, Optional.empty(), Optional.empty(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestStatisticAggregationsDescriptor.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestStatisticAggregationsDescriptor.java index 42c42859f331..73241ec422bd 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestStatisticAggregationsDescriptor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestStatisticAggregationsDescriptor.java @@ -23,6 +23,8 @@ import io.trino.sql.planner.SymbolAllocator; import org.testng.annotations.Test; +import java.util.Optional; + import static io.trino.spi.statistics.TableStatisticType.ROW_COUNT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.testing.assertions.Assert.assertEquals; @@ -52,8 +54,8 @@ private static StatisticAggregationsDescriptor createTestDescriptor() for (ColumnStatisticType type : ColumnStatisticType.values()) { builder.addColumnStatistic(new ColumnStatisticMetadata(column, type), testSymbol(symbolAllocator)); } - builder.addColumnStatistic(new ColumnStatisticMetadata(column, new FunctionName("count")), testSymbol(symbolAllocator)); - builder.addColumnStatistic(new ColumnStatisticMetadata(column, new FunctionName("count_if")), testSymbol(symbolAllocator)); + builder.addColumnStatistic(new ColumnStatisticMetadata(column, "count non null", new FunctionName("count"), Optional.empty()), testSymbol(symbolAllocator)); + builder.addColumnStatistic(new ColumnStatisticMetadata(column, "count true", new FunctionName("count_if"), Optional.empty()), testSymbol(symbolAllocator)); builder.addGrouping(column, testSymbol(symbolAllocator)); } builder.addTableStatistic(ROW_COUNT, testSymbol(symbolAllocator)); diff --git a/core/trino-spi/src/main/java/io/trino/spi/statistics/ColumnStatisticMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/statistics/ColumnStatisticMetadata.java index 8062ff788312..462e49185929 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/statistics/ColumnStatisticMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/statistics/ColumnStatisticMetadata.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.Experimental; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FunctionName; import java.util.Objects; @@ -28,45 +29,57 @@ public class ColumnStatisticMetadata { private final String columnName; + private final String connectorAggregationId; private final Optional statisticType; private final Optional aggregation; + private final Optional projection; public ColumnStatisticMetadata( String columnName, ColumnStatisticType statisticType) { - this(columnName, Optional.of(statisticType), Optional.empty()); + this(columnName, statisticType.name(), Optional.of(statisticType), Optional.empty(), Optional.empty()); } @Experimental(eta = "2023-01-31") public ColumnStatisticMetadata( String columnName, - FunctionName aggregation) + String connectorAggregationId, + FunctionName aggregation, + Optional projection) { - this(columnName, Optional.empty(), Optional.of(aggregation)); + this(columnName, connectorAggregationId, Optional.empty(), Optional.of(aggregation), projection); } private ColumnStatisticMetadata( String columnName, + String connectorAggregationId, Optional statisticType, - Optional aggregation) + Optional aggregation, + Optional projection) { this.columnName = requireNonNull(columnName, "columnName is null"); + this.connectorAggregationId = requireNonNull(connectorAggregationId, "connectorAggregationId is null"); this.statisticType = requireNonNull(statisticType, "statisticType is null"); this.aggregation = requireNonNull(aggregation, "aggregation is null"); + this.projection = requireNonNull(projection, "projection is null"); if (statisticType.isPresent() == aggregation.isPresent()) { throw new IllegalArgumentException("Exactly one of statisticType and aggregation should be set"); } + if (projection.isPresent() && aggregation.isEmpty()) { + throw new IllegalArgumentException("Projection can only be present when aggregation is"); + } } @Deprecated // For JSON deserialization only @JsonCreator public static ColumnStatisticMetadata fromJson( @JsonProperty("columnName") String columnName, + @JsonProperty("connectorAggregationId") String connectorAggregationId, @JsonProperty("statisticType") Optional statisticType, @JsonProperty("aggregation") Optional aggregation) { - return new ColumnStatisticMetadata(columnName, statisticType, aggregation); + return new ColumnStatisticMetadata(columnName, connectorAggregationId, statisticType, aggregation, Optional.empty()); } @JsonProperty @@ -75,6 +88,13 @@ public String getColumnName() return columnName; } + @Experimental(eta = "2023-01-31") + @JsonProperty + public String getConnectorAggregationId() + { + return connectorAggregationId; + } + @JsonIgnore public ColumnStatisticType getStatisticType() { @@ -102,6 +122,12 @@ public Optional getAggregationIfPresent() return aggregation; } + @JsonIgnore // not needed on workers + public Optional getProjection() + { + return projection; + } + @Override public boolean equals(Object o) { @@ -113,14 +139,16 @@ public boolean equals(Object o) } ColumnStatisticMetadata that = (ColumnStatisticMetadata) o; return Objects.equals(columnName, that.columnName) && + Objects.equals(connectorAggregationId, that.connectorAggregationId) && Objects.equals(statisticType, that.statisticType) && - Objects.equals(aggregation, that.aggregation); + Objects.equals(aggregation, that.aggregation) && + Objects.equals(projection, that.projection); } @Override public int hashCode() { - return Objects.hash(columnName, statisticType, aggregation); + return Objects.hash(columnName, connectorAggregationId, statisticType, aggregation, projection); } @Override @@ -128,8 +156,10 @@ public String toString() { return new StringJoiner(", ", ColumnStatisticMetadata.class.getSimpleName() + "[", "]") .add("columnName='" + columnName + "'") + .add("connectorAggregationId='" + connectorAggregationId + "'") .add("statisticType=" + statisticType) .add("aggregation=" + aggregation) + .add("projection=" + projection) .toString(); } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index e55e2126175a..0a366ba9af0e 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -1490,7 +1490,7 @@ public ConnectorAnalyzeMetadata getStatisticsCollectionMetadata(ConnectorSession Set columnStatistics = tableMetadata.getColumns().stream() .filter(column -> analyzeColumnNames.contains(column.getName())) // TODO: add support for NDV summary/sketch, but using Theta sketch, not HLL; see https://github.com/apache/iceberg-docs/pull/69 - .map(column -> new ColumnStatisticMetadata(column.getName(), NUMBER_OF_DISTINCT_VALUES)) + .map(column -> new ColumnStatisticMetadata(column.getName(), "ndv", NUMBER_OF_DISTINCT_VALUES, Optional.empty())) .collect(toImmutableSet()); return new ConnectorAnalyzeMetadata(