diff --git a/plugin/trino-iceberg/pom.xml b/plugin/trino-iceberg/pom.xml index 44abac67385c..3d750bb228e8 100644 --- a/plugin/trino-iceberg/pom.xml +++ b/plugin/trino-iceberg/pom.xml @@ -40,6 +40,12 @@ jackson-databind + + com.google.code.findbugs + jsr305 + true + + com.google.errorprone error_prone_annotations @@ -142,6 +148,11 @@ trino-hive-formats + + io.trino + trino-matching + + io.trino trino-memory-context diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java index e1a63f78c32c..074272e7a128 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java @@ -32,6 +32,7 @@ import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.plugin.iceberg.IcebergMetadataColumn.FILE_MODIFIED_TIME; import static io.trino.plugin.iceberg.IcebergMetadataColumn.FILE_PATH; +import static io.trino.plugin.iceberg.aggregation.AggregateExpression.COUNT_AGGREGATE_COLUMN_ID; import static java.util.Objects.requireNonNull; import static org.apache.iceberg.MetadataColumns.IS_DELETED; import static org.apache.iceberg.MetadataColumns.ROW_POSITION; @@ -299,4 +300,9 @@ public boolean isPathColumn() { return getColumnIdentity().getId() == FILE_PATH.getId(); } + + public boolean isAggregateColumn() + { + return getColumnIdentity().getId() == COUNT_AGGREGATE_COLUMN_ID; + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java index 2dc0615d2f44..503e9975eeda 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConfig.java @@ -79,6 +79,7 @@ public class IcebergConfig private boolean sortedWritingEnabled = true; private boolean queryPartitionFilterRequired; private int splitManagerThreads = Runtime.getRuntime().availableProcessors() * 2; + private boolean aggregationPushdownEnabled; public CatalogType getCatalogType() { @@ -435,4 +436,16 @@ public boolean isStorageSchemaSetWhenHidingIsEnabled() { return hideMaterializedViewStorageTable && materializedViewsStorageSchema.isPresent(); } + + public boolean isAggregationPushdownEnabled() + { + return aggregationPushdownEnabled; + } + + @Config("iceberg.aggregation-pushdown.enabled") + public IcebergConfig setAggregationPushdownEnabled(boolean aggregationPushdownEnabled) + { + this.aggregationPushdownEnabled = aggregationPushdownEnabled; + return this; + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index 1a9ea9571ac8..3886afdcf120 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -32,14 +32,19 @@ import io.trino.filesystem.FileIterator; import io.trino.filesystem.Location; import io.trino.filesystem.TrinoFileSystem; +import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.base.classloader.ClassLoaderSafeSystemTable; +import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import 
io.trino.plugin.base.filter.UtcConstraintExtractor; import io.trino.plugin.base.projection.ApplyProjectionUtil; import io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation; import io.trino.plugin.hive.HiveWrittenPartitions; import io.trino.plugin.hive.metastore.TableInfo; +import io.trino.plugin.iceberg.aggregation.AggregateExpression; import io.trino.plugin.iceberg.aggregation.DataSketchStateSerializer; import io.trino.plugin.iceberg.aggregation.IcebergThetaSketchForStats; +import io.trino.plugin.iceberg.aggregation.ImplementCountAll; import io.trino.plugin.iceberg.catalog.TrinoCatalog; import io.trino.plugin.iceberg.procedure.IcebergDropExtendedStatsHandle; import io.trino.plugin.iceberg.procedure.IcebergExpireSnapshotsHandle; @@ -51,6 +56,8 @@ import io.trino.spi.ErrorCode; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.connector.AggregationApplicationResult; import io.trino.spi.connector.Assignment; import io.trino.spi.connector.BeginTableExecuteResult; import io.trino.spi.connector.CatalogHandle; @@ -135,6 +142,7 @@ import org.apache.iceberg.SchemaParser; import org.apache.iceberg.Snapshot; import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.SnapshotSummary; import org.apache.iceberg.SortField; import org.apache.iceberg.SortOrder; import org.apache.iceberg.StatisticsFile; @@ -220,6 +228,7 @@ import static io.trino.plugin.iceberg.IcebergSessionProperties.getExpireSnapshotMinRetention; import static io.trino.plugin.iceberg.IcebergSessionProperties.getHiveCatalogName; import static io.trino.plugin.iceberg.IcebergSessionProperties.getRemoveOrphanFilesMinRetention; +import static io.trino.plugin.iceberg.IcebergSessionProperties.isAggregationPushdownEnabled; import static io.trino.plugin.iceberg.IcebergSessionProperties.isCollectExtendedStatisticsOnWrite; import static io.trino.plugin.iceberg.IcebergSessionProperties.isExtendedStatisticsEnabled; import static io.trino.plugin.iceberg.IcebergSessionProperties.isMergeManifestsOnWrite; @@ -327,6 +336,7 @@ public class IcebergMetadata private final TrinoCatalog catalog; private final IcebergFileSystemFactory fileSystemFactory; private final TableStatisticsWriter tableStatisticsWriter; + private final AggregateFunctionRewriter aggregateFunctionRewriter; private final Map tableStatisticsCache = new ConcurrentHashMap<>(); @@ -346,6 +356,12 @@ public IcebergMetadata( this.catalog = requireNonNull(catalog, "catalog is null"); this.fileSystemFactory = requireNonNull(fileSystemFactory, "fileSystemFactory is null"); this.tableStatisticsWriter = requireNonNull(tableStatisticsWriter, "tableStatisticsWriter is null"); + + this.aggregateFunctionRewriter = new AggregateFunctionRewriter( + new ConnectorExpressionRewriter<>(ImmutableSet.of()), + ImmutableSet.>builder() + .add(new ImplementCountAll()) + .build()); } @Override @@ -2653,6 +2669,101 @@ else if (isMetadataColumnId(columnHandle.getId())) { false)); } + @Override + public Optional> applyAggregation( + ConnectorSession session, + ConnectorTableHandle handle, + List aggregates, + Map assignments, + List> groupingSets) + { + IcebergTableHandle tableHandle = (IcebergTableHandle) handle; + + // Iceberg's metadata cannot be used for the aggregation calculation when the table has equality deletes, + // because equality deletes are not reflected in the metadata-level record counts.
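+ // The pushdown is applied only when the count can be answered purely from manifest metadata: the session
+ // property must be enabled, no unenforced predicate may remain on the table handle, there must be no
+ // grouping sets beyond the single global one, and exactly one aggregate that rewrites to count(*).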
+ if (hasEqualityDeletes(session, tableHandle)) { + return Optional.empty(); + } + + if (!isAggregationPushdownEnabled(session)) { + return Optional.empty(); + } + + // not supporting unenforced predicate + if (!tableHandle.getUnenforcedPredicate().isNone() + && tableHandle.getUnenforcedPredicate().getDomains().isPresent() + && !tableHandle.getUnenforcedPredicate().getDomains().get().isEmpty()) { + return Optional.empty(); + } + + // not supporting group by + if (!groupingSets.equals(List.of(List.of()))) { + return Optional.empty(); + } + + ImmutableList.Builder projections = ImmutableList.builder(); + ImmutableList.Builder resultAssignments = ImmutableList.builder(); + ImmutableList.Builder aggregateColumnsBuilder = ImmutableList.builder(); + + Set projectionsSet = new HashSet<>(); + + if (aggregates.size() != 1) { + // not handling multiple aggregations for now + return Optional.empty(); + } + + AggregateFunction aggregate = aggregates.get(0); + + Optional rewriteResult = aggregateFunctionRewriter.rewrite(session, aggregate, assignments); + if (rewriteResult.isEmpty()) { + return Optional.empty(); + } + AggregateExpression aggregateExpression = rewriteResult.get(); + + if (aggregateExpression.getFunction().startsWith("count")) { + IcebergColumnHandle aggregateIcebergColumnHandle = new IcebergColumnHandle(aggregateExpression.toColumnIdentity(AggregateExpression.COUNT_AGGREGATE_COLUMN_ID), + aggregate.getOutputType(), List.of(), aggregate.getOutputType(), false, Optional.empty()); + aggregateColumnsBuilder.add(aggregateIcebergColumnHandle); + projections.add(new Variable(aggregateIcebergColumnHandle.getName(), aggregateIcebergColumnHandle.getType())); + projectionsSet.add(aggregateIcebergColumnHandle); + resultAssignments.add(new Assignment(aggregateIcebergColumnHandle.getName(), aggregateIcebergColumnHandle, aggregateIcebergColumnHandle.getType())); + } + + IcebergTableHandle tableHandleTemp = new IcebergTableHandle( + tableHandle.getCatalog(), + tableHandle.getSchemaName(), + tableHandle.getTableName(), + tableHandle.getTableType(), + tableHandle.getSnapshotId(), + tableHandle.getTableSchemaJson(), + tableHandle.getPartitionSpecJson(), + tableHandle.getFormatVersion(), + tableHandle.getUnenforcedPredicate(), + tableHandle.getEnforcedPredicate(), + tableHandle.getLimit(), + projectionsSet, + tableHandle.getNameMappingJson(), + tableHandle.getTableLocation(), + tableHandle.getStorageProperties(), + tableHandle.isRecordScannedFiles(), + tableHandle.getMaxScannedFileSize(), + tableHandle.getConstraintColumns(), + tableHandle.getForAnalyze()); + + return Optional.of(new AggregationApplicationResult<>(tableHandleTemp, projections.build(), resultAssignments.build(), ImmutableMap.of(), false)); + } + + private boolean hasEqualityDeletes(ConnectorSession session, IcebergTableHandle tableHandle) + { + Table icebergTable = catalog.loadTable(session, tableHandle.getSchemaTableName()); + + if (icebergTable.currentSnapshot().summary().containsKey(SnapshotSummary.TOTAL_EQ_DELETES_PROP)) { + return (Long.parseLong(icebergTable.currentSnapshot().summary().get(SnapshotSummary.TOTAL_EQ_DELETES_PROP)) > 0); + } + + return false; + } + private static Set identityPartitionColumnsInAllSpecs(Table table) { // Extract identity partition column source ids common to ALL specs diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index ce363f84f6c5..2fd43a09a14b 100644 --- 
a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -57,6 +57,8 @@ import io.trino.plugin.hive.parquet.ParquetPageSource; import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.plugin.iceberg.IcebergParquetColumnIOConverter.FieldContext; +import io.trino.plugin.iceberg.aggregation.AggregateIcebergSplit; +import io.trino.plugin.iceberg.aggregation.AggregatePageSource; import io.trino.plugin.iceberg.delete.DeleteFile; import io.trino.plugin.iceberg.delete.DeleteFilter; import io.trino.plugin.iceberg.delete.EqualityDeleteFilter; @@ -251,36 +253,43 @@ public ConnectorPageSource createPageSource( List columns, DynamicFilter dynamicFilter) { - IcebergSplit split = (IcebergSplit) connectorSplit; List icebergColumns = columns.stream() .map(IcebergColumnHandle.class::cast) .collect(toImmutableList()); - IcebergTableHandle tableHandle = (IcebergTableHandle) connectorTable; - Schema schema = SchemaParser.fromJson(tableHandle.getTableSchemaJson()); - PartitionSpec partitionSpec = PartitionSpecParser.fromJson(schema, split.getPartitionSpecJson()); - org.apache.iceberg.types.Type[] partitionColumnTypes = partitionSpec.fields().stream() - .map(field -> field.transform().getResultType(schema.findType(field.sourceId()))) - .toArray(org.apache.iceberg.types.Type[]::new); - - return createPageSource( - session, - icebergColumns, - schema, - partitionSpec, - PartitionData.fromJson(split.getPartitionDataJson(), partitionColumnTypes), - split.getDeletes(), - dynamicFilter, - tableHandle.getUnenforcedPredicate(), - split.getFileStatisticsDomain(), - split.getPath(), - split.getStart(), - split.getLength(), - split.getFileSize(), - split.getFileRecordCount(), - split.getPartitionDataJson(), - split.getFileFormat(), - split.getFileIoProperties(), - tableHandle.getNameMappingJson().map(NameMappingParser::fromJson)); + + if (shouldHandleAggregatePushDown(icebergColumns)) { + AggregateIcebergSplit aggregateIcebergSplit = (AggregateIcebergSplit) connectorSplit; + return new AggregatePageSource(icebergColumns, aggregateIcebergSplit.getTotalCount()); + } + else { + IcebergSplit split = (IcebergSplit) connectorSplit; + IcebergTableHandle tableHandle = (IcebergTableHandle) connectorTable; + Schema schema = SchemaParser.fromJson(tableHandle.getTableSchemaJson()); + PartitionSpec partitionSpec = PartitionSpecParser.fromJson(schema, split.getPartitionSpecJson()); + org.apache.iceberg.types.Type[] partitionColumnTypes = partitionSpec.fields().stream() + .map(field -> field.transform().getResultType(schema.findType(field.sourceId()))) + .toArray(org.apache.iceberg.types.Type[]::new); + + return createPageSource( + session, + icebergColumns, + schema, + partitionSpec, + PartitionData.fromJson(split.getPartitionDataJson(), partitionColumnTypes), + split.getDeletes(), + dynamicFilter, + tableHandle.getUnenforcedPredicate(), + split.getFileStatisticsDomain(), + split.getPath(), + split.getStart(), + split.getLength(), + split.getFileSize(), + split.getFileRecordCount(), + split.getPartitionDataJson(), + split.getFileFormat(), + split.getFileIoProperties(), + tableHandle.getNameMappingJson().map(NameMappingParser::fromJson)); + } } public ConnectorPageSource createPageSource( @@ -1542,6 +1551,11 @@ private static TrinoException handleException(ParquetDataSourceId dataSourceId, return new TrinoException(ICEBERG_CURSOR_ERROR, format("Failed to read Parquet file: %s", 
dataSourceId), exception); } + private static boolean shouldHandleAggregatePushDown(List columns) + { + return columns.size() == 1 && columns.get(0).isAggregateColumn(); + } + public static final class ReaderPageSourceWithRowPositions { private final ReaderPageSource readerPageSource; diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java index 5f8501ec23a8..284e780c6b34 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSessionProperties.java @@ -97,6 +97,7 @@ public final class IcebergSessionProperties private static final String MERGE_MANIFESTS_ON_WRITE = "merge_manifests_on_write"; private static final String SORTED_WRITING_ENABLED = "sorted_writing_enabled"; private static final String QUERY_PARTITION_FILTER_REQUIRED = "query_partition_filter_required"; + public static final String AGGREGATION_PUSHDOWN_ENABLED = "aggregation_pushdown_enabled"; private final List> sessionProperties; @@ -348,6 +349,11 @@ public IcebergSessionProperties( "Require filter on partition column", icebergConfig.isQueryPartitionFilterRequired(), false)) + .add(booleanProperty( + AGGREGATION_PUSHDOWN_ENABLED, + "Enable Aggregation Pushdown", + icebergConfig.isAggregationPushdownEnabled(), + false)) .build(); } @@ -568,4 +574,9 @@ public static boolean isQueryPartitionFilterRequired(ConnectorSession session) { return session.getProperty(QUERY_PARTITION_FILTER_REQUIRED, Boolean.class); } + + public static boolean isAggregationPushdownEnabled(ConnectorSession session) + { + return session.getProperty(AGGREGATION_PUSHDOWN_ENABLED, Boolean.class); + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java index fb313676841e..5b0a1d8a0271 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java @@ -18,6 +18,7 @@ import io.airlift.units.Duration; import io.trino.filesystem.cache.CachingHostAddressProvider; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorSplitSource; +import io.trino.plugin.iceberg.aggregation.AggregateSplitSource; import io.trino.plugin.iceberg.functions.tablechanges.TableChangesFunctionHandle; import io.trino.plugin.iceberg.functions.tablechanges.TableChangesSplitSource; import io.trino.spi.connector.ConnectorSession; @@ -35,6 +36,7 @@ import java.util.concurrent.ExecutorService; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.iceberg.IcebergSessionProperties.getDynamicFilteringWaitTimeout; import static io.trino.plugin.iceberg.IcebergSessionProperties.getMinimumAssignedSplitWeight; import static io.trino.spi.connector.FixedSplitSource.emptySplitSource; @@ -90,20 +92,39 @@ public ConnectorSplitSource getSplits( .useSnapshot(table.getSnapshotId().get()) .planWith(executor); - IcebergSplitSource splitSource = new IcebergSplitSource( - fileSystemFactory, - session, - table, - icebergTable.io().properties(), - tableScan, - table.getMaxScannedFileSize(), - dynamicFilter, - dynamicFilteringWaitTimeout, - constraint, - typeManager, - table.isRecordScannedFiles(), - getMinimumAssignedSplitWeight(session), - 
cachingHostAddressProvider); + ConnectorSplitSource splitSource = null; + + if (shouldHandleAggregatePushDown(table)) { + splitSource = new AggregateSplitSource( + fileSystemFactory, + session, + table, + icebergTable.io().properties(), + tableScan, + table.getMaxScannedFileSize(), + dynamicFilter, + dynamicFilteringWaitTimeout, + constraint, + typeManager, + table.isRecordScannedFiles(), + getMinimumAssignedSplitWeight(session)); + } + else { + splitSource = new IcebergSplitSource( + fileSystemFactory, + session, + table, + icebergTable.io().properties(), + tableScan, + table.getMaxScannedFileSize(), + dynamicFilter, + dynamicFilteringWaitTimeout, + constraint, + typeManager, + table.isRecordScannedFiles(), + getMinimumAssignedSplitWeight(session), + cachingHostAddressProvider); + } return new ClassLoaderSafeConnectorSplitSource(splitSource, IcebergSplitManager.class.getClassLoader()); } @@ -127,4 +148,15 @@ public ConnectorSplitSource getSplits( throw new IllegalStateException("Unknown table function: " + function); } + + private static boolean shouldHandleAggregatePushDown(IcebergTableHandle icebergTableHandle) + { + if (icebergTableHandle.getProjectedColumns().size() != 1) { + return false; + } + + return icebergTableHandle.getProjectedColumns().stream().filter(stringColumnHandleEntry -> { + return stringColumnHandleEntry.isAggregateColumn(); + }).collect(toImmutableList()).size() > 0; + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateExpression.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateExpression.java new file mode 100644 index 000000000000..9d6cd9050c50 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateExpression.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg.aggregation; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.trino.plugin.iceberg.ColumnIdentity; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static io.trino.plugin.iceberg.IcebergMetadataColumn.FILE_MODIFIED_TIME; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class AggregateExpression +{ + private final String function; + private final String argument; + // There is no obviously cleaner way to assign an id to the aggregate's synthetic column. Trino maintains + // FILE_MODIFIED_TIME as a synthetic metadata column with id Integer.MAX_VALUE - 1001, so FILE_MODIFIED_TIME - 1001 + // is used here to keep some buffer below that reserved id.
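+ // The resulting id is Integer.MAX_VALUE - 2002, well clear of both real schema field ids and the ids reserved for Trino's metadata columns.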
+ public static final Integer COUNT_AGGREGATE_COLUMN_ID = FILE_MODIFIED_TIME.getId() - 1001; + + @JsonCreator + public AggregateExpression(@JsonProperty String function, @JsonProperty String argument) + { + this.function = requireNonNull(function, "function is null"); + this.argument = requireNonNull(argument, "argument is null"); + } + + @JsonProperty + public String getFunction() + { + return function; + } + + @JsonProperty + public String getArgument() + { + return argument; + } + + public ColumnIdentity toColumnIdentity(Integer columnId) + { + return new ColumnIdentity(columnId, format("%s(%s)", function, argument), ColumnIdentity.TypeCategory.PRIMITIVE, ImmutableList.of()); + } + + @Override + public boolean equals(Object other) + { + if (this == other) { + return true; + } + if (!(other instanceof AggregateExpression)) { + return false; + } + AggregateExpression that = (AggregateExpression) other; + return that.function.equals(function) && + that.argument.equals(argument); + } + + @Override + public int hashCode() + { + return Objects.hash(function, argument); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("function", function) + .add("argument", argument) + .toString(); + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateIcebergSplit.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateIcebergSplit.java new file mode 100644 index 000000000000..9a7afd80f345 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateIcebergSplit.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.iceberg.aggregation; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.iceberg.IcebergSplit; +import io.trino.spi.HostAddress; +import io.trino.spi.connector.ConnectorSplit; + +import java.util.List; +import java.util.OptionalLong; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; + +public class AggregateIcebergSplit + implements ConnectorSplit +{ + private static final int INSTANCE_SIZE = instanceSize(IcebergSplit.class); + private final List addresses; + private final long totalCount; + + @JsonCreator + public AggregateIcebergSplit( + @JsonProperty("addresses") List addresses, + @JsonProperty("totalCount") long totalCount) + { + this.addresses = addresses; + this.totalCount = totalCount; + } + + @JsonProperty + public long getTotalCount() + { + return totalCount; + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + sizeOf(OptionalLong.of(totalCount)); + } + + @Override + + public boolean isRemotelyAccessible() + { + return true; + } + + @JsonProperty + @Override + public List getAddresses() + { + return addresses; + } + + @Override + public Object getInfo() + { + return ImmutableMap.builder() + .put("totalCount", totalCount) + .buildOrThrow(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .addValue(totalCount) + .toString(); + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregatePageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregatePageSource.java new file mode 100644 index 000000000000..fcb1b4a9557f --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregatePageSource.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg.aggregation; + +import io.trino.plugin.iceberg.IcebergColumnHandle; +import io.trino.plugin.iceberg.util.PageListBuilder; +import io.trino.spi.Page; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.metrics.Metrics; +import io.trino.spi.type.Type; + +import java.util.Iterator; +import java.util.List; +import java.util.OptionalLong; +import java.util.concurrent.CompletableFuture; + +import static java.util.stream.Collectors.toList; + +public class AggregatePageSource + implements ConnectorPageSource +{ + private final List columnTypes; + private long readTimeNanos; + private Iterator pages; + private final long recordCount; + + public AggregatePageSource(List columnHandles, long recordCount) + { + // _pos columns are not required. 
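+ // Only the type of the synthetic count(*) column is kept; getNextPage emits a single page with one BIGINT row
+ // carrying the record count that was already computed by AggregateSplitSource.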
+ this.columnTypes = columnHandles.stream().filter(columnHandle -> !columnHandle.isRowPositionColumn()).map(ch -> ch.getType()).collect(toList()); + this.recordCount = recordCount; + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + /** + * Gets the number of input rows processed by this page source so far. + * By default, the positions count of the page returned from getNextPage + * is used to calculate the number of input rows. + */ + @Override + public OptionalLong getCompletedPositions() + { + return ConnectorPageSource.super.getCompletedPositions(); + } + + @Override + public long getReadTimeNanos() + { + return readTimeNanos; + } + + @Override + public boolean isFinished() + { + return pages != null && !pages.hasNext(); + } + + @Override + public Page getNextPage() + { + if (pages != null && pages.hasNext()) { + return pages.next(); + } + + long start = System.nanoTime(); + PageListBuilder pageListBuilder = new PageListBuilder(columnTypes); + + pageListBuilder.beginRow(); + pageListBuilder.appendBigint(recordCount); + pageListBuilder.endRow(); + + this.readTimeNanos += System.nanoTime() - start; + this.pages = pageListBuilder.build().iterator(); + return pages.next(); + } + + /** + * Get the total memory that needs to be reserved in the memory pool. + * This memory should include any buffers, etc. that are used for reading data. + * + * @return the memory used so far in table read + */ + @Override + public long getMemoryUsage() + { + return 0; + } + + @Override + public void close() + { + } + + /** + * Returns a future that will be completed when the page source becomes + * unblocked. If the page source is not blocked, this method should return + * {@code NOT_BLOCKED}. + */ + @Override + public CompletableFuture isBlocked() + { + return ConnectorPageSource.super.isBlocked(); + } + + /** + * Returns the connector's metrics, mapping a metric ID to its latest value. + * Each call must return an immutable snapshot of available metrics. + * Same ID metrics are merged across all tasks and exposed via OperatorStats. + * This method can be called after the page source is closed. + */ + @Override + public Metrics getMetrics() + { + return ConnectorPageSource.super.getMetrics(); + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateSplitSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateSplitSource.java new file mode 100644 index 000000000000..f8c1a9399766 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/AggregateSplitSource.java @@ -0,0 +1,489 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.iceberg.aggregation; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Stopwatch; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.io.Closer; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoInputFile; +import io.trino.plugin.iceberg.IcebergColumnHandle; +import io.trino.plugin.iceberg.IcebergFileSystemFactory; +import io.trino.plugin.iceberg.IcebergTableHandle; +import io.trino.plugin.iceberg.util.DataFileWithDeleteFiles; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.connector.ConnectorSplitSource; +import io.trino.spi.connector.Constraint; +import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.NullableValue; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.TypeManager; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.util.TableScanUtil; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Suppliers.memoize; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.intersection; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.plugin.iceberg.ExpressionConverter.toIcebergExpression; +import static io.trino.plugin.iceberg.IcebergColumnHandle.fileModifiedTimeColumnHandle; +import static io.trino.plugin.iceberg.IcebergColumnHandle.pathColumnHandle; +import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_FILESYSTEM_ERROR; +import static io.trino.plugin.iceberg.IcebergMetadataColumn.isMetadataColumnId; +import static io.trino.plugin.iceberg.IcebergSplitManager.ICEBERG_DOMAIN_COMPACTION_THRESHOLD; +import static io.trino.plugin.iceberg.IcebergTypes.convertIcebergValueToTrino; +import static io.trino.plugin.iceberg.IcebergUtil.deserializePartitionValue; +import static io.trino.plugin.iceberg.IcebergUtil.getColumnHandle; +import static io.trino.plugin.iceberg.IcebergUtil.getPartitionKeys; +import static io.trino.plugin.iceberg.IcebergUtil.primitiveFieldTypes; +import static io.trino.plugin.iceberg.TypeConverter.toIcebergType; +import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; +import 
static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.apache.iceberg.types.Conversions.fromByteBuffer; + +public class AggregateSplitSource + implements ConnectorSplitSource +{ + private static final ConnectorSplitBatch EMPTY_BATCH = new ConnectorSplitBatch(ImmutableList.of(), false); + private static final ConnectorSplitBatch NO_MORE_SPLITS_BATCH = new ConnectorSplitBatch(ImmutableList.of(), true); + + private final IcebergFileSystemFactory fileSystemFactory; + protected final ConnectorSession session; + protected final IcebergTableHandle tableHandle; + protected final TableScan tableScan; + protected final Optional maxScannedFileSizeInBytes; + protected final Map fieldIdToType; + protected final DynamicFilter dynamicFilter; + protected final long dynamicFilteringWaitTimeoutMillis; + protected final Stopwatch dynamicFilterWaitStopwatch; + protected final Constraint constraint; + protected final TypeManager typeManager; + protected final Closer closer = Closer.create(); + protected final double minimumAssignedSplitWeight; + protected final TupleDomain dataColumnPredicate; + protected final Domain pathDomain; + protected final Domain fileModifiedTimeDomain; + + private CloseableIterable fileScanTaskIterable; + private CloseableIterator fileScanTaskIterator; + protected TupleDomain pushedDownDynamicFilterPredicate; + + protected final boolean recordScannedFiles; + private final ImmutableSet.Builder scannedFiles = ImmutableSet.builder(); + + private final Map fileIoProperties; + + public AggregateSplitSource( + IcebergFileSystemFactory fileSystemFactory, + ConnectorSession session, + IcebergTableHandle tableHandle, + Map fileIoProperties, + TableScan tableScan, + Optional maxScannedFileSize, + DynamicFilter dynamicFilter, + Duration dynamicFilteringWaitTimeout, + Constraint constraint, + TypeManager typeManager, + boolean recordScannedFiles, + double minimumAssignedSplitWeight) + { + this.fileSystemFactory = requireNonNull(fileSystemFactory, "fileSystemFactory is null"); + this.session = requireNonNull(session, "session is null"); + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.fileIoProperties = fileIoProperties; + this.tableScan = requireNonNull(tableScan, "tableScan is null"); + this.maxScannedFileSizeInBytes = maxScannedFileSize.map(DataSize::toBytes); + this.fieldIdToType = primitiveFieldTypes(tableScan.schema()); + this.dynamicFilter = requireNonNull(dynamicFilter, "dynamicFilter is null"); + this.dynamicFilteringWaitTimeoutMillis = dynamicFilteringWaitTimeout.toMillis(); + this.dynamicFilterWaitStopwatch = Stopwatch.createStarted(); + this.constraint = requireNonNull(constraint, "constraint is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.recordScannedFiles = recordScannedFiles; + this.minimumAssignedSplitWeight = minimumAssignedSplitWeight; + this.dataColumnPredicate = tableHandle.getEnforcedPredicate().filter((column, domain) -> !isMetadataColumnId(column.getId())); + this.pathDomain = getPathDomain(tableHandle.getEnforcedPredicate()); + this.fileModifiedTimeDomain = getFileModifiedTimePathDomain(tableHandle.getEnforcedPredicate()); + } + + @Override + public CompletableFuture getNextBatch(int maxSize) + { + long timeLeft = dynamicFilteringWaitTimeoutMillis - dynamicFilterWaitStopwatch.elapsed(MILLISECONDS); + if (dynamicFilter.isAwaitable() && timeLeft > 0) { + return dynamicFilter.isBlocked() + .thenApply(ignored -> EMPTY_BATCH) + .completeOnTimeout(EMPTY_BATCH, timeLeft, MILLISECONDS); + } + 
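+ // All matching file scan tasks are folded into one running total in this single call: each data file's
+ // recordCount is added once, the record counts of its delete files are subtracted, and the result is
+ // returned to the engine as a single AggregateIcebergSplit.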
+ long recordCount = 0L; + Set uniqueFileName = new HashSet<>(); + + if (fileScanTaskIterable == null) { + // Used to avoid duplicating work if the Dynamic Filter was already pushed down to the Iceberg API + boolean dynamicFilterIsComplete = dynamicFilter.isComplete(); + this.pushedDownDynamicFilterPredicate = dynamicFilter.getCurrentPredicate().transformKeys(IcebergColumnHandle.class::cast); + TupleDomain fullPredicate = tableHandle.getUnenforcedPredicate() + .intersect(pushedDownDynamicFilterPredicate); + // TODO: (https://github.com/trinodb/trino/issues/9743): Consider removing TupleDomain#simplify + TupleDomain simplifiedPredicate = fullPredicate.simplify(ICEBERG_DOMAIN_COMPACTION_THRESHOLD); + boolean usedSimplifiedPredicate = !simplifiedPredicate.equals(fullPredicate); + if (usedSimplifiedPredicate) { + // Pushed down predicate was simplified, always evaluate it against individual splits + this.pushedDownDynamicFilterPredicate = TupleDomain.all(); + } + + TupleDomain effectivePredicate = dataColumnPredicate + .intersect(simplifiedPredicate); + + if (effectivePredicate.isNone()) { + finish(); + return completedFuture(NO_MORE_SPLITS_BATCH); + } + + Expression filterExpression = toIcebergExpression(effectivePredicate); + // If the Dynamic Filter will be evaluated against each file, stats are required. Otherwise, skip them. + boolean requiresColumnStats = usedSimplifiedPredicate || !dynamicFilterIsComplete; + TableScan scan = tableScan.filter(filterExpression); + if (requiresColumnStats) { + scan = scan.includeColumnStats(); + } + this.fileScanTaskIterable = TableScanUtil.splitFiles(scan.planFiles(), tableScan.targetSplitSize()); + closer.register(fileScanTaskIterable); + this.fileScanTaskIterator = fileScanTaskIterable.iterator(); + closer.register(fileScanTaskIterator); + // TODO: Remove when NPE check has been released: https://github.com/trinodb/trino/issues/15372 + isFinished(); + } + + TupleDomain dynamicFilterPredicate = dynamicFilter.getCurrentPredicate() + .transformKeys(IcebergColumnHandle.class::cast); + if (dynamicFilterPredicate.isNone()) { + finish(); + return completedFuture(NO_MORE_SPLITS_BATCH); + } + + // Note: the batch maxSize is deliberately ignored, because the aggregate is computed here in a single pass. + // Producing it as multiple splits would require the engine to apply a final top-level aggregation, + // but once a connector accepts an aggregation pushdown the engine no longer evaluates the aggregate function itself. + ImmutableList.Builder splits = ImmutableList.builder(); + while (fileScanTaskIterator.hasNext()) { + FileScanTask scanTask = fileScanTaskIterator.next(); + + if (!uniqueFileName.add(scanTask.file().path().toString())) { + // A data file larger than the target split size (128 MB by default) yields multiple FileScanTasks, + // so skip duplicates to avoid counting the same file's record-count metadata more than once.
+ continue; + } + + if (scanTask.deletes().isEmpty() && + maxScannedFileSizeInBytes.isPresent() && + scanTask.file().fileSizeInBytes() > maxScannedFileSizeInBytes.get()) { + continue; + } + + if (!pathDomain.includesNullableValue(utf8Slice(scanTask.file().path().toString()))) { + continue; + } + if (!fileModifiedTimeDomain.isAll()) { + long fileModifiedTime = getModificationTime(scanTask.file().path().toString()); + if (!fileModifiedTimeDomain.includesNullableValue(packDateTimeWithZone(fileModifiedTime, UTC_KEY))) { + continue; + } + } + + Schema fileSchema = scanTask.spec().schema(); + Map> partitionKeys = getPartitionKeys(scanTask); + + Set identityPartitionColumns = partitionKeys.keySet().stream() + .map(fieldId -> getColumnHandle(fileSchema.findField(fieldId), typeManager)) + .collect(toImmutableSet()); + + Supplier> partitionValues = memoize(() -> { + Map bindings = new HashMap<>(); + for (IcebergColumnHandle partitionColumn : identityPartitionColumns) { + Object partitionValue = deserializePartitionValue( + partitionColumn.getType(), + partitionKeys.get(partitionColumn.getId()).orElse(null), + partitionColumn.getName()); + NullableValue bindingValue = new NullableValue(partitionColumn.getType(), partitionValue); + bindings.put(partitionColumn, bindingValue); + } + return bindings; + }); + + if (!dynamicFilterPredicate.isAll() && !dynamicFilterPredicate.equals(pushedDownDynamicFilterPredicate)) { + if (!partitionMatchesPredicate( + identityPartitionColumns, + partitionValues, + dynamicFilterPredicate)) { + continue; + } + if (!fileMatchesPredicate( + fieldIdToType, + dynamicFilterPredicate, + scanTask.file().lowerBounds(), + scanTask.file().upperBounds(), + scanTask.file().nullValueCounts())) { + continue; + } + } + if (!partitionMatchesConstraint(identityPartitionColumns, partitionValues, constraint)) { + continue; + } + + if (recordScannedFiles) { + // Positional and Equality deletes can only be cleaned up if the whole table has been optimized. + // Equality deletes may apply to many files, and position deletes may be grouped together. This makes it difficult to know if they are obsolete. + List fullyAppliedDeletes = tableHandle.getEnforcedPredicate().isAll() ? 
scanTask.deletes() : ImmutableList.of(); + scannedFiles.add(new DataFileWithDeleteFiles(scanTask.file(), fullyAppliedDeletes)); + } + + recordCount = recordCount + scanTask.file().recordCount() - getDeletedFilesCount(scanTask); + } + + if (recordCount >= 0) { + AggregateIcebergSplit icebergSplit = toIcebergSplit(recordCount); + splits.add(icebergSplit); + } + + return completedFuture(new ConnectorSplitBatch(splits.build(), isFinished())); + } + + private long getModificationTime(String path) + { + try { + TrinoInputFile inputFile = fileSystemFactory.create(session.getIdentity(), fileIoProperties).newInputFile(Location.of(path)); + return inputFile.lastModified().toEpochMilli(); + } + catch (IOException e) { + throw new TrinoException(ICEBERG_FILESYSTEM_ERROR, "Failed to get file modification time: " + path, e); + } + } + + private void finish() + { + close(); + this.fileScanTaskIterable = CloseableIterable.empty(); + this.fileScanTaskIterator = CloseableIterator.empty(); + } + + @Override + public boolean isFinished() + { + return fileScanTaskIterator != null && !fileScanTaskIterator.hasNext(); + } + + @Override + public Optional> getTableExecuteSplitsInfo() + { + checkState(isFinished(), "Split source must be finished before TableExecuteSplitsInfo is read"); + if (!recordScannedFiles) { + return Optional.empty(); + } + return Optional.of(ImmutableList.copyOf(scannedFiles.build())); + } + + @Override + public void close() + { + try { + closer.close(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @VisibleForTesting + static boolean fileMatchesPredicate( + Map primitiveTypeForFieldId, + TupleDomain dynamicFilterPredicate, + @Nullable Map lowerBounds, + @Nullable Map upperBounds, + @Nullable Map nullValueCounts) + { + if (dynamicFilterPredicate.isNone()) { + return false; + } + Map domains = dynamicFilterPredicate.getDomains().orElseThrow(); + + for (Map.Entry domainEntry : domains.entrySet()) { + IcebergColumnHandle column = domainEntry.getKey(); + Domain domain = domainEntry.getValue(); + + int fieldId = column.getId(); + boolean mayContainNulls; + if (nullValueCounts == null) { + mayContainNulls = true; + } + else { + Long nullValueCount = nullValueCounts.get(fieldId); + mayContainNulls = nullValueCount == null || nullValueCount > 0; + } + Type type = primitiveTypeForFieldId.get(fieldId); + Domain statisticsDomain = domainForStatistics( + column, + lowerBounds == null ? null : fromByteBuffer(type, lowerBounds.get(fieldId)), + upperBounds == null ? 
null : fromByteBuffer(type, upperBounds.get(fieldId)), + mayContainNulls); + if (!domain.overlaps(statisticsDomain)) { + return false; + } + } + return true; + } + + private static Domain domainForStatistics( + IcebergColumnHandle columnHandle, + @Nullable Object lowerBound, + @Nullable Object upperBound, + boolean mayContainNulls) + { + io.trino.spi.type.Type type = columnHandle.getType(); + Type icebergType = toIcebergType(type, columnHandle.getColumnIdentity()); + if (lowerBound == null && upperBound == null) { + return Domain.create(ValueSet.all(type), mayContainNulls); + } + + Range statisticsRange; + if (lowerBound != null && upperBound != null) { + statisticsRange = Range.range( + type, + convertIcebergValueToTrino(icebergType, lowerBound), + true, + convertIcebergValueToTrino(icebergType, upperBound), + true); + } + else if (upperBound != null) { + statisticsRange = Range.lessThanOrEqual(type, convertIcebergValueToTrino(icebergType, upperBound)); + } + else { + statisticsRange = Range.greaterThanOrEqual(type, convertIcebergValueToTrino(icebergType, lowerBound)); + } + return Domain.create(ValueSet.ofRanges(statisticsRange), mayContainNulls); + } + + static boolean partitionMatchesConstraint( + Set identityPartitionColumns, + Supplier> partitionValues, + Constraint constraint) + { + // We use Constraint just to pass functional predicate here from DistributedExecutionPlanner + verify(constraint.getSummary().isAll()); + + if (constraint.predicate().isEmpty() || + intersection(constraint.getPredicateColumns().orElseThrow(), identityPartitionColumns).isEmpty()) { + return true; + } + return constraint.predicate().get().test(partitionValues.get()); + } + + @VisibleForTesting + static boolean partitionMatchesPredicate( + Set identityPartitionColumns, + Supplier> partitionValues, + TupleDomain dynamicFilterPredicate) + { + if (dynamicFilterPredicate.isNone()) { + return false; + } + Map domains = dynamicFilterPredicate.getDomains().orElseThrow(); + + for (IcebergColumnHandle partitionColumn : identityPartitionColumns) { + Domain allowedDomain = domains.get(partitionColumn); + if (allowedDomain != null) { + if (!allowedDomain.includesNullableValue(partitionValues.get().get(partitionColumn).getValue())) { + return false; + } + } + } + return true; + } + + private AggregateIcebergSplit toIcebergSplit(long totalCount) + { + return new AggregateIcebergSplit( + ImmutableList.of(), + totalCount); + } + + private static Domain getPathDomain(TupleDomain effectivePredicate) + { + IcebergColumnHandle pathColumn = pathColumnHandle(); + Domain domain = effectivePredicate.getDomains().orElseThrow(() -> new IllegalArgumentException("Unexpected NONE tuple domain")) + .get(pathColumn); + if (domain == null) { + return Domain.all(pathColumn.getType()); + } + return domain; + } + + private static Domain getFileModifiedTimePathDomain(TupleDomain effectivePredicate) + { + IcebergColumnHandle fileModifiedTimeColumn = fileModifiedTimeColumnHandle(); + Domain domain = effectivePredicate.getDomains().orElseThrow(() -> new IllegalArgumentException("Unexpected NONE tuple domain")) + .get(fileModifiedTimeColumn); + if (domain == null) { + return Domain.all(fileModifiedTimeColumn.getType()); + } + return domain; + } + + private static long getDeletedFilesCount(FileScanTask fileScanTask) + { + AtomicLong deletedCount = new AtomicLong(); + fileScanTask.deletes().stream().forEach(deleteFile -> { + deletedCount.getAndAdd(deleteFile.recordCount()); + }); + + return deletedCount.longValue(); + } +} diff --git 
a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/ImplementCountAll.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/ImplementCountAll.java new file mode 100644 index 000000000000..8a2b9d368ed3 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/ImplementCountAll.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg.aggregation; + +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.spi.connector.AggregateFunction; + +import java.util.List; +import java.util.Optional; + +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.arguments; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.outputType; +import static io.trino.spi.type.BigintType.BIGINT; + +public class ImplementCountAll + implements AggregateFunctionRule +{ + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("count")) + .with(arguments().equalTo(List.of())) + .with(outputType().equalTo(BIGINT)); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + return Optional.of(new AggregateExpression("count", "*")); + } +} diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index d5171456180b..91bd7d89ddba 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -450,35 +450,103 @@ private void testSelectOrPartitionedByTime(boolean partitioned) @Test public void testPartitionByTimestamp() { - testSelectOrPartitionedByTimestamp(true); + testSelectOrPartitionedByTimestamp(true, false); + } + + @Test + public void testPartitionByTimestampWithAggregationPushdown() + { + testSelectOrPartitionedByTimestamp(true, true); } @Test public void testSelectByTimestamp() { - testSelectOrPartitionedByTimestamp(false); + testSelectOrPartitionedByTimestamp(false, false); + } + + @Test + public void testSelectByTimestampWithAggregationPushdown() + { + testSelectOrPartitionedByTimestamp(false, true); } - private void testSelectOrPartitionedByTimestamp(boolean partitioned) + private void testSelectOrPartitionedByTimestamp(boolean partitioned, boolean enableAggregationPushdown) { + Session clientSession = getSession(); + if (enableAggregationPushdown) { + clientSession = sessionWithAggregationPushdown(); + } + String tableName = format("test_%s_by_timestamp", partitioned ? 
"partitioned" : "selected"); - assertUpdate(format("CREATE TABLE %s (_timestamp timestamp(6)) %s", + assertUpdate(clientSession, format("CREATE TABLE %s (_timestamp timestamp(6)) %s", tableName, partitioned ? "WITH (partitioning = ARRAY['_timestamp'])" : "")); @Language("SQL") String select1 = "SELECT TIMESTAMP '2017-05-01 10:12:34' _timestamp"; @Language("SQL") String select2 = "SELECT TIMESTAMP '2017-10-01 10:12:34' _timestamp"; @Language("SQL") String select3 = "SELECT TIMESTAMP '2018-05-01 10:12:34' _timestamp"; - assertUpdate(format("INSERT INTO %s %s", tableName, select1), 1); - assertUpdate(format("INSERT INTO %s %s", tableName, select2), 1); - assertUpdate(format("INSERT INTO %s %s", tableName, select3), 1); - assertQuery(format("SELECT COUNT(*) from %s", tableName), "SELECT 3"); - - assertQuery(format("SELECT * from %s WHERE _timestamp = TIMESTAMP '2017-05-01 10:12:34'", tableName), select1); - assertQuery(format("SELECT * from %s WHERE _timestamp < TIMESTAMP '2017-06-01 10:12:34'", tableName), select1); - assertQuery(format("SELECT * from %s WHERE _timestamp = TIMESTAMP '2017-10-01 10:12:34'", tableName), select2); - assertQuery(format("SELECT * from %s WHERE _timestamp > TIMESTAMP '2017-06-01 10:12:34' AND _timestamp < TIMESTAMP '2018-05-01 10:12:34'", tableName), select2); - assertQuery(format("SELECT * from %s WHERE _timestamp = TIMESTAMP '2018-05-01 10:12:34'", tableName), select3); - assertQuery(format("SELECT * from %s WHERE _timestamp > TIMESTAMP '2018-01-01 10:12:34'", tableName), select3); - assertUpdate("DROP TABLE " + tableName); + assertUpdate(clientSession, format("INSERT INTO %s %s", tableName, select1), 1); + assertUpdate(clientSession, format("INSERT INTO %s %s", tableName, select2), 1); + assertUpdate(clientSession, format("INSERT INTO %s %s", tableName, select3), 1); + assertQuery(clientSession, format("SELECT COUNT(*) from %s", tableName), "SELECT 3"); + + assertQuery(clientSession, format("SELECT * from %s WHERE _timestamp = TIMESTAMP '2017-05-01 10:12:34'", tableName), select1); + assertQuery(clientSession, format("SELECT * from %s WHERE _timestamp < TIMESTAMP '2017-06-01 10:12:34'", tableName), select1); + assertQuery(clientSession, format("SELECT * from %s WHERE _timestamp = TIMESTAMP '2017-10-01 10:12:34'", tableName), select2); + assertQuery(clientSession, format("SELECT * from %s WHERE _timestamp > TIMESTAMP '2017-06-01 10:12:34' AND _timestamp < TIMESTAMP '2018-05-01 10:12:34'", tableName), select2); + assertQuery(clientSession, format("SELECT * from %s WHERE _timestamp = TIMESTAMP '2018-05-01 10:12:34'", tableName), select3); + assertQuery(clientSession, format("SELECT * from %s WHERE _timestamp > TIMESTAMP '2018-01-01 10:12:34'", tableName), select3); + dropTable(tableName); + } + + @Test + public void testMultiplePartitionedWithAggregationPushdown() + { + Session clientSessionWithAggregationPushdown = sessionWithAggregationPushdown(); + + String tableName = format("test_%s_with_aggregation_pushdown", "multiple_partition"); + assertUpdate(clientSessionWithAggregationPushdown, format("CREATE TABLE %s (userid int, country varchar, event_date date, state varchar) %s", + tableName, "WITH (partitioning = ARRAY['event_date', 'country'])")); + + assertUpdate(format("INSERT INTO %s VALUES (1, 'USA', DATE '2022-11-01', 'California'), (2, 'USA', DATE '2022-11-01', 'Ohio')", tableName), 2); + assertUpdate(format("INSERT INTO %s VALUES (3, 'FRA', DATE '2022-11-02', 'Brittany'), (4, 'USA', DATE '2022-11-02', 'NJ')", tableName), 2); + assertUpdate(format("INSERT 
INTO %s VALUES (5, 'USA', DATE '2022-11-04', 'Nevada')", tableName), 1); + + assertQuery(clientSessionWithAggregationPushdown, format("SELECT COUNT(*) from %s", tableName), "SELECT 5"); + assertQuery(format("SELECT COUNT(*) from %s", tableName), "SELECT 5"); + + assertThat(query(clientSessionWithAggregationPushdown, format("SELECT userid, country, event_date, state FROM %s WHERE event_date = DATE '2022-11-01'", tableName))) + .matches("VALUES (1, VARCHAR 'USA', DATE '2022-11-01', VARCHAR 'California'), (2, VARCHAR 'USA', DATE '2022-11-01', VARCHAR 'Ohio')"); + + assertThat(query(format("SELECT userid, country, event_date, state FROM %s WHERE event_date = DATE '2022-11-01'", tableName))) + .matches("VALUES (1, VARCHAR 'USA', DATE '2022-11-01', VARCHAR 'California'), (2, 'USA', DATE '2022-11-01', 'Ohio')"); + + assertQuery(clientSessionWithAggregationPushdown, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 2"); + assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 2"); + + assertQuery(clientSessionWithAggregationPushdown, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA'", tableName), "SELECT 2"); + assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA'", tableName), "SELECT 2"); + + // non partition delete + assertUpdate(format("DELETE FROM %s WHERE event_date = DATE '2022-11-02' AND state = 'Brittany'", tableName), 1); + + assertQuery(clientSessionWithAggregationPushdown, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02'", tableName), "SELECT 1"); + assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02'", tableName), "SELECT 1"); + + assertQuery(clientSessionWithAggregationPushdown, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA' AND state = 'California'", tableName), "SELECT 1"); + assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA' AND state = 'California'", tableName), "SELECT 1"); + + assertQuery(clientSessionWithAggregationPushdown, format("SELECT COUNT(*) from %s WHERE event_date >= DATE '2022-11-01' and country = 'USA'", tableName), "SELECT 4"); + assertQuery(format("SELECT COUNT(*) from %s WHERE event_date >= DATE '2022-11-01' and country = 'USA'", tableName), "SELECT 4"); + + assertUpdate(format("INSERT INTO %s VALUES (6, 'FRA', DATE '2022-11-05', 'Brittany'), (7, 'USA', DATE '2022-11-05', 'NJ')", tableName), 2); + + // partition delete + assertUpdate(format("DELETE FROM %s WHERE event_date = DATE '2022-11-05' AND country = 'USA'", tableName), 1); + + assertQuery(clientSessionWithAggregationPushdown, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-05'", tableName), "SELECT 1"); + assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-05'", tableName), "SELECT 1"); + + dropTable(tableName); } @Test @@ -8185,4 +8253,19 @@ private void assertQueryIdStored(String tableName, QueryId queryId) assertThat(getFieldFromLatestSnapshotSummary(tableName, TRINO_QUERY_ID_NAME)) .isEqualTo(queryId.toString()); } + + private Session sessionWithAggregationPushdown() + { + return Session.builder(getSession()) + // Enable aggregation pushdown + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), IcebergSessionProperties.AGGREGATION_PUSHDOWN_ENABLED, "true") + .build(); + } + + private void dropTable(String table) + { + Session session = 
+        assertUpdate(session, "DROP TABLE " + table);
+        assertThat(getQueryRunner().tableExists(session, table)).isFalse();
+    }
 }
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java
index c4f61e49d8e7..2db5bb1bd2f5 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergConfig.java
@@ -67,7 +67,8 @@ public void testDefaults()
                 .setRegisterTableProcedureEnabled(false)
                 .setSortedWritingEnabled(true)
                 .setQueryPartitionFilterRequired(false)
-                .setSplitManagerThreads(Runtime.getRuntime().availableProcessors() * 2));
+                .setSplitManagerThreads(Runtime.getRuntime().availableProcessors() * 2)
+                .setAggregationPushdownEnabled(false));
     }
 
     @Test
@@ -99,6 +100,7 @@ public void testExplicitPropertyMappings()
                 .put("iceberg.sorted-writing-enabled", "false")
                 .put("iceberg.query-partition-filter-required", "true")
                 .put("iceberg.split-manager-threads", "42")
+                .put("iceberg.aggregation-pushdown.enabled", "true")
                 .buildOrThrow();
 
         IcebergConfig expected = new IcebergConfig()
@@ -126,7 +128,8 @@ public void testExplicitPropertyMappings()
                 .setRegisterTableProcedureEnabled(true)
                 .setSortedWritingEnabled(false)
                 .setQueryPartitionFilterRequired(true)
-                .setSplitManagerThreads(42);
+                .setSplitManagerThreads(42)
+                .setAggregationPushdownEnabled(true);
 
         assertFullMapping(properties, expected);
     }
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java
index b77be164b939..8ce38a8b2f0f 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java
@@ -210,6 +210,26 @@ public void testV2TableWithEqualityDelete()
         assertQuery("SELECT nationkey, comment FROM " + tableName, "SELECT nationkey, comment FROM nation WHERE regionkey != 1");
     }
 
+    @Test
+    public void testV2TableWithEqualityDeleteWithAggregationPushdownEnabled()
+            throws Exception
+    {
+        Session aggregationEnabledSession = sessionWithAggregationPushdown();
+
+        String tableName = "test_v2_equality_delete" + randomNameSuffix();
+        assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM tpch.tiny.nation", 25);
+        assertQuery("SELECT count(*) FROM " + tableName, "SELECT 25");
+        Table icebergTable = updateTableToV2(tableName);
+        writeEqualityDeleteToNationTable(icebergTable, Optional.of(icebergTable.spec()), Optional.of(new PartitionData(new Long[]{1L})));
+        assertQuery(aggregationEnabledSession, "SELECT * FROM " + tableName, "SELECT * FROM nation WHERE regionkey != 1");
+
+        assertQuery(aggregationEnabledSession, "SELECT count(*) FROM " + tableName, "SELECT 20");
+        assertQuery("SELECT count(*) FROM " + tableName, "SELECT 20");
+
+        // nationkey is before the equality delete column in the table schema, comment is after
+        assertQuery(aggregationEnabledSession, "SELECT nationkey, comment FROM " + tableName, "SELECT nationkey, comment FROM nation WHERE regionkey != 1");
+    }
+
     @Test
     public void testV2TableWithEqualityDeleteDifferentColumnOrder()
             throws Exception
@@ -684,6 +704,72 @@ public void testDeletingEntireFileWithNonTupleDomainConstraint()
         assertThat(this.loadTable(tableName).newScan().planFiles()).hasSize(2);
     }
 
+    @Test
+    public void testMultiplePartitionedWithAggregationPushdown()
+    {
+        Session clientSession = sessionWithAggregationPushdown();
+
+        String tableName = "test_multiple_partition_with_aggregation_pushdown";
+        assertUpdate(format("CREATE TABLE %s (userid int, country varchar, event_date date, state varchar) %s",
+                tableName, "WITH (partitioning = ARRAY['event_date', 'country'])"));
+
+        assertUpdate(format("INSERT INTO %s VALUES (1, 'USA', DATE '2022-11-01', 'California'), (2, 'USA', DATE '2022-11-01', 'Ohio')", tableName), 2);
+        assertUpdate(format("INSERT INTO %s VALUES (3, 'IND', DATE '2022-11-01', 'Delhi'), (4, 'IND', DATE '2022-11-01', 'MP')", tableName), 2);
+        assertUpdate(format("INSERT INTO %s VALUES (5, 'FRA', DATE '2022-11-02', 'Brittany'), (6, 'USA', DATE '2022-11-02', 'Corsica')", tableName), 2);
+        assertUpdate(format("INSERT INTO %s VALUES (7, 'USA', DATE '2022-11-04', 'Nevada')", tableName), 1);
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s", tableName), "SELECT 7");
+        assertQuery(format("SELECT COUNT(*) from %s", tableName), "SELECT 7");
+
+        assertThat(query(clientSession, format("SELECT userid, country, event_date, state FROM %s WHERE event_date = DATE '2022-11-01' order by userid", tableName)))
+                .matches("VALUES (1, VARCHAR 'USA', DATE '2022-11-01', VARCHAR 'California'), (2, VARCHAR 'USA', DATE '2022-11-01', VARCHAR 'Ohio')"
+                        + ", (3, VARCHAR 'IND', DATE '2022-11-01', VARCHAR 'Delhi'), (4, VARCHAR 'IND', DATE '2022-11-01', VARCHAR 'MP')");
+
+        assertThat(query(format("SELECT userid, country, event_date, state FROM %s WHERE event_date = DATE '2022-11-01' order by userid", tableName)))
+                .matches("VALUES (1, VARCHAR 'USA', DATE '2022-11-01', VARCHAR 'California'), (2, VARCHAR 'USA', DATE '2022-11-01', VARCHAR 'Ohio')"
+                        + ", (3, VARCHAR 'IND', DATE '2022-11-01', VARCHAR 'Delhi'), (4, VARCHAR 'IND', DATE '2022-11-01', VARCHAR 'MP')");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 4");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 4");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA'", tableName), "SELECT 2");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA'", tableName), "SELECT 2");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA' AND state = 'California'", tableName), "SELECT 1");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country = 'USA' AND state = 'California'", tableName), "SELECT 1");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country IN ('USA', 'IND')", tableName), "SELECT 4");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND country IN ('USA', 'IND')", tableName), "SELECT 4");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date >= DATE '2022-11-01' and country = 'USA'", tableName), "SELECT 4");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date >= DATE '2022-11-01' and country = 'USA'", tableName), "SELECT 4");
+
+        assertUpdate(format("DELETE FROM %s WHERE event_date = DATE '2022-11-01' AND country = 'USA' AND state = 'California'", tableName), 1);
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 3");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 3");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND state = 'California'", tableName), "SELECT 0");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01' AND state = 'California'", tableName), "SELECT 0");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02'", tableName), "SELECT 2");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02'", tableName), "SELECT 2");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-04'", tableName), "SELECT 1");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-04'", tableName), "SELECT 1");
+
+        assertUpdate(format("DELETE FROM %s WHERE event_date = DATE '2022-11-01' AND country = 'USA'", tableName), 1);
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 2");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-01'", tableName), "SELECT 2");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02' AND country = 'FRA' AND state = 'Brittany'", tableName), "SELECT 1");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02' AND country = 'FRA' AND state = 'Brittany'", tableName), "SELECT 1");
+
+        assertQuery(clientSession, format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02' AND country = 'USA' AND state = 'Corsica'", tableName), "SELECT 1");
+        assertQuery(format("SELECT COUNT(*) from %s WHERE event_date = DATE '2022-11-02' AND country = 'USA' AND state = 'Corsica'", tableName), "SELECT 1");
+    }
+
     @Test
     public void testDeletingEntireFileWithMultipleSplits()
     {
@@ -1025,4 +1111,12 @@ private List<String> getActiveFiles(String tableName)
                 .map(String.class::cast)
                 .collect(toImmutableList());
     }
+
+    private Session sessionWithAggregationPushdown()
+    {
+        return Session.builder(getSession())
+                // Enable aggregation pushdown
+                .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), IcebergSessionProperties.AGGREGATION_PUSHDOWN_ENABLED, "true")
+                .build();
+    }
 }
diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
index 7b7ab823d6af..6f3f2490435f 100644
--- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
+++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
@@ -1750,6 +1750,56 @@ public void testTrinoReadsSparkRowLevelDeletes(StorageFormat tableStorageFormat,
         onSpark().executeQuery("DROP TABLE " + sparkTableName);
     }
 
+    @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "tableFormatWithDeleteFormat")
+    public void testTrinoReadsSparkRowLevelDeletesWithAggregatePushdown(StorageFormat tableStorageFormat, StorageFormat deleteFileStorageFormat)
+    {
+        String tableName = format("test_trino_reads_with_aggregate_pushdown_spark_row_level_deletes_%s_%s_%s", tableStorageFormat.name(), deleteFileStorageFormat.name(), randomNameSuffix());
+        String sparkTableName = sparkTableName(tableName);
+        String trinoTableName = trinoTableName(tableName);
+
+        onSpark().executeQuery("CREATE TABLE " + sparkTableName + "(a INT, b INT) "
+                + "USING ICEBERG PARTITIONED BY (b) "
+                + "TBLPROPERTIES ('format-version'='2', 'write.delete.mode'='merge-on-read',"
+                + "'write.format.default'='" + tableStorageFormat.name() + "',"
+                + "'write.delete.format.default'='" + deleteFileStorageFormat.name() + "')");
+        onSpark().executeQuery("INSERT INTO " + sparkTableName + " VALUES (1, 2), (2, 2), (3, 2), (11, 12), (12, 12), (13, 12)");
+        // Spark inserts may create multiple files. rewrite_data_files ensures it is compacted to one file so a row-level delete occurs.
+        onSpark().executeQuery("CALL " + SPARK_CATALOG + ".system.rewrite_data_files(table=>'" + TEST_SCHEMA_NAME + "." + tableName + "', options => map('min-input-files','1'))");
+        // Delete one row in a file
+        onSpark().executeQuery("DELETE FROM " + sparkTableName + " WHERE a = 13");
+        // Delete an entire partition
+        onSpark().executeQuery("DELETE FROM " + sparkTableName + " WHERE b = 2");
+
+        List<Row> expected = ImmutableList.of(row(11, 12), row(12, 12));
+
+        onTrino().executeQuery("SET SESSION iceberg.aggregation_pushdown_enabled = true");
+
+        assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(expected);
+        assertThat(onSpark().executeQuery("SELECT * FROM " + sparkTableName)).containsOnly(expected);
+
+        assertThat(onTrino().executeQuery("SELECT count(*) FROM " + trinoTableName)).containsOnly(ImmutableList.of(row(2)));
+        assertThat(onSpark().executeQuery("SELECT count(*) FROM " + sparkTableName)).containsOnly(ImmutableList.of(row(2)));
+        assertThat(onTrino().executeQuery("SELECT count(*) FROM " + trinoTableName + " WHERE b = 12")).containsOnly(ImmutableList.of(row(2)));
+        assertThat(onSpark().executeQuery("SELECT count(*) FROM " + sparkTableName + " WHERE b = 12")).containsOnly(ImmutableList.of(row(2)));
+        assertThat(onTrino().executeQuery("SELECT count(*) FROM " + trinoTableName + " WHERE a IN (11, 12)")).containsOnly(ImmutableList.of(row(2)));
+        assertThat(onSpark().executeQuery("SELECT count(*) FROM " + sparkTableName + " WHERE a IN (11, 12)")).containsOnly(ImmutableList.of(row(2)));
+
+        // Delete from a file that already has deleted rows
+        onSpark().executeQuery("DELETE FROM " + sparkTableName + " WHERE a = 12");
+        expected = ImmutableList.of(row(11, 12));
+        assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(expected);
+        assertThat(onSpark().executeQuery("SELECT * FROM " + sparkTableName)).containsOnly(expected);
+
+        assertThat(onTrino().executeQuery("SELECT count(*) FROM " + trinoTableName)).containsOnly(ImmutableList.of(row(1)));
+        assertThat(onSpark().executeQuery("SELECT count(*) FROM " + sparkTableName)).containsOnly(ImmutableList.of(row(1)));
+        assertThat(onTrino().executeQuery("SELECT count(*) FROM " + trinoTableName + " WHERE b = 12")).containsOnly(ImmutableList.of(row(1)));
+        assertThat(onSpark().executeQuery("SELECT count(*) FROM " + sparkTableName + " WHERE b = 12")).containsOnly(ImmutableList.of(row(1)));
+        assertThat(onTrino().executeQuery("SELECT count(*) FROM " + trinoTableName + " WHERE a IN (11, 12)")).containsOnly(ImmutableList.of(row(1)));
+        assertThat(onSpark().executeQuery("SELECT count(*) FROM " + sparkTableName + " WHERE a IN (11, 12)")).containsOnly(ImmutableList.of(row(1)));
+
+        onSpark().executeQuery("DROP TABLE " + sparkTableName);
+    }
+
     @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS, ICEBERG_REST, ICEBERG_JDBC}, dataProvider = "tableFormatWithDeleteFormat")
     public void testTrinoReadsSparkRowLevelDeletesWithRowTypes(StorageFormat tableStorageFormat, StorageFormat deleteFileStorageFormat)
     {