diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/PartitionTable.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/PartitionTable.java index ccde9fbdc89f..7ebe6d0aa74e 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/PartitionTable.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/PartitionTable.java @@ -14,6 +14,7 @@ package io.trino.plugin.iceberg; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ColumnMetadata; @@ -37,16 +38,20 @@ import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; import org.apache.iceberg.util.StructLikeWrapper; import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.stream.IntStream; import java.util.stream.Stream; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -59,6 +64,7 @@ import static io.trino.spi.type.TypeUtils.writeNativeValue; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toSet; +import static java.util.stream.Collectors.toUnmodifiableSet; public class PartitionTable implements SystemTable @@ -67,9 +73,9 @@ public class PartitionTable private final Table icebergTable; private final Optional snapshotId; private final Map idToTypeMapping; - private final List nonPartitionPrimitiveColumns; - private final Optional partitionColumnType; - private final List partitionColumnTypes; + private final List nonPartitionPrimitiveColumns; + private final Optional partitionColumnType; + private final List partitionFields; private final Optional dataColumnType; private final List columnMetricTypes; private final List resultTypes; @@ -82,21 +88,14 @@ public PartitionTable(SchemaTableName tableName, TypeManager typeManager, Table this.snapshotId = requireNonNull(snapshotId, "snapshotId is null"); this.idToTypeMapping = primitiveFieldTypes(icebergTable.schema()); - List columns = icebergTable.schema().columns(); - List partitionFields = icebergTable.spec().fields(); + List columns = icebergTable.schema().columns(); + this.partitionFields = getAllPartitionFields(icebergTable); ImmutableList.Builder columnMetadataBuilder = ImmutableList.builder(); this.partitionColumnType = getPartitionColumnType(partitionFields, icebergTable.schema()); - if (partitionColumnType.isPresent()) { - columnMetadataBuilder.add(new ColumnMetadata("partition", partitionColumnType.get())); - this.partitionColumnTypes = partitionColumnType.get().getFields().stream() - .map(RowType.Field::getType) - .collect(toImmutableList()); - } - else { - this.partitionColumnTypes = ImmutableList.of(); - } + partitionColumnType.ifPresent(icebergPartitionColumn -> + columnMetadataBuilder.add(new ColumnMetadata("partition", icebergPartitionColumn.rowType))); Stream.of("record_count", "file_count", "total_size") .forEach(metric -> columnMetadataBuilder.add(new ColumnMetadata(metric, BIGINT))); @@ -140,20 +139,53 @@ public ConnectorTableMetadata getTableMetadata() return connectorTableMetadata; } - private Optional getPartitionColumnType(List fields, Schema schema) + private static List getAllPartitionFields(Table icebergTable) { + Set existingColumnsIds = icebergTable.schema() + .columns().stream() + .map(NestedField::fieldId) + .collect(toUnmodifiableSet()); + + List visiblePartitionFields = icebergTable.specs() + .values().stream() + .flatMap(partitionSpec -> partitionSpec.fields().stream()) + // skip columns that were dropped + .filter(partitionField -> existingColumnsIds.contains(partitionField.sourceId())) + .collect(toImmutableList()); + + return filterOutDuplicates(visiblePartitionFields); + } + + private static List filterOutDuplicates(List visiblePartitionFields) + { + Set alreadyExistingFieldIds = new HashSet<>(); + List result = new ArrayList<>(); + for (PartitionField partitionField : visiblePartitionFields) { + if (!alreadyExistingFieldIds.contains(partitionField.fieldId())) { + alreadyExistingFieldIds.add(partitionField.fieldId()); + result.add(partitionField); + } + } + return result; + } + + private Optional getPartitionColumnType(List fields, Schema schema) + { + if (fields.isEmpty()) { + return Optional.empty(); + } List partitionFields = fields.stream() .map(field -> RowType.field( field.name(), toTrinoType(field.transform().getResultType(schema.findType(field.sourceId())), typeManager))) .collect(toImmutableList()); - if (partitionFields.isEmpty()) { - return Optional.empty(); - } - return Optional.of(RowType.from(partitionFields)); + List fieldIds = fields.stream() + .map(PartitionField::fieldId) + .collect(toImmutableList()); + return Optional.of(new IcebergPartitionColumn(RowType.from(partitionFields), fieldIds)); } - private Optional getMetricsColumnType(List columns) + private Optional getMetricsColumnType(List columns) { List metricColumns = columns.stream() .map(column -> RowType.field( @@ -180,21 +212,22 @@ public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, Connect .useSnapshot(snapshotId.get()) .includeColumnStats(); // TODO make the cursor lazy - return buildRecordCursor(getStatisticsByPartition(tableScan), icebergTable.spec().fields()); + return buildRecordCursor(getStatisticsByPartition(tableScan)); } - private Map getStatisticsByPartition(TableScan tableScan) + private Map getStatisticsByPartition(TableScan tableScan) { try (CloseableIterable fileScanTasks = tableScan.planFiles()) { - Map partitions = new HashMap<>(); + Map partitions = new HashMap<>(); for (FileScanTask fileScanTask : fileScanTasks) { DataFile dataFile = fileScanTask.file(); Types.StructType structType = fileScanTask.spec().partitionType(); StructLike partitionStruct = dataFile.partition(); StructLikeWrapper partitionWrapper = StructLikeWrapper.forType(structType).set(partitionStruct); + StructLikeWrapperWithFieldIdToIndex structLikeWrapperWithFieldIdToIndex = new StructLikeWrapperWithFieldIdToIndex(partitionWrapper, structType); partitions.computeIfAbsent( - partitionWrapper, + structLikeWrapperWithFieldIdToIndex, ignored -> new IcebergStatistics.Builder(icebergTable.schema().columns(), typeManager)) .acceptDataFile(dataFile, fileScanTask.spec()); } @@ -207,31 +240,40 @@ private Map getStatisticsByPartition(Table } } - private RecordCursor buildRecordCursor(Map partitionStatistics, List partitionFields) + private RecordCursor buildRecordCursor(Map partitionStatistics) { - List partitionTypes = partitionTypes(partitionFields); + List partitionTypes = partitionTypes(); List> partitionColumnClass = partitionTypes.stream() .map(type -> type.typeId().javaClass()) .collect(toImmutableList()); ImmutableList.Builder> records = ImmutableList.builder(); - for (Map.Entry partitionEntry : partitionStatistics.entrySet()) { - StructLikeWrapper partitionStruct = partitionEntry.getKey(); + for (Map.Entry partitionEntry : partitionStatistics.entrySet()) { + StructLikeWrapperWithFieldIdToIndex partitionStruct = partitionEntry.getKey(); IcebergStatistics icebergStatistics = partitionEntry.getValue(); List row = new ArrayList<>(); // add data for partition columns partitionColumnType.ifPresent(partitionColumnType -> { - BlockBuilder partitionRowBlockBuilder = partitionColumnType.createBlockBuilder(null, 1); + BlockBuilder partitionRowBlockBuilder = partitionColumnType.rowType.createBlockBuilder(null, 1); BlockBuilder partitionBlockBuilder = partitionRowBlockBuilder.beginBlockEntry(); + List partitionColumnTypes = partitionColumnType.rowType.getFields().stream() + .map(RowType.Field::getType) + .collect(toImmutableList()); for (int i = 0; i < partitionColumnTypes.size(); i++) { - io.trino.spi.type.Type trinoType = partitionColumnType.getFields().get(i).getType(); - Object value = convertIcebergValueToTrino(partitionTypes.get(i), partitionStruct.get().get(i, partitionColumnClass.get(i))); + io.trino.spi.type.Type trinoType = partitionColumnType.rowType.getFields().get(i).getType(); + Object value = null; + Integer fieldId = partitionColumnType.fieldIds.get(i); + if (partitionStruct.fieldIdToIndex.containsKey(fieldId)) { + value = convertIcebergValueToTrino( + partitionTypes.get(i), + partitionStruct.structLikeWrapper.get().get(partitionStruct.fieldIdToIndex.get(fieldId), partitionColumnClass.get(i))); + } writeNativeValue(trinoType, partitionBlockBuilder, value); } partitionRowBlockBuilder.closeEntry(); - row.add(partitionColumnType.getObject(partitionRowBlockBuilder, 0)); + row.add(partitionColumnType.rowType.getObject(partitionRowBlockBuilder, 0)); }); // add the top level metrics. @@ -268,7 +310,7 @@ private RecordCursor buildRecordCursor(Map return new InMemoryRecordSet(resultTypes, records.build()).cursor(); } - private List partitionTypes(List partitionFields) + private List partitionTypes() { ImmutableList.Builder partitionTypeBuilder = ImmutableList.builder(); for (PartitionField partitionField : partitionFields) { @@ -292,4 +334,70 @@ private static Block getColumnMetricBlock(RowType columnMetricType, Object min, rowBlockBuilder.closeEntry(); return columnMetricType.getObject(rowBlockBuilder, 0); } + + private static class StructLikeWrapperWithFieldIdToIndex + { + private final StructLikeWrapper structLikeWrapper; + private final Map fieldIdToIndex; + + public StructLikeWrapperWithFieldIdToIndex(StructLikeWrapper structLikeWrapper, Types.StructType structType) + { + this.structLikeWrapper = structLikeWrapper; + ImmutableMap.Builder fieldIdToIndex = ImmutableMap.builder(); + List fields = structType.fields(); + IntStream.range(0, fields.size()) + .forEach(i -> fieldIdToIndex.put(fields.get(i).fieldId(), i)); + this.fieldIdToIndex = fieldIdToIndex.buildOrThrow(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + StructLikeWrapperWithFieldIdToIndex that = (StructLikeWrapperWithFieldIdToIndex) o; + return Objects.equals(structLikeWrapper, that.structLikeWrapper) && Objects.equals(fieldIdToIndex, that.fieldIdToIndex); + } + + @Override + public int hashCode() + { + return Objects.hash(structLikeWrapper, fieldIdToIndex); + } + } + + private static class IcebergPartitionColumn + { + private final RowType rowType; + private final List fieldIds; + + public IcebergPartitionColumn(RowType rowType, List fieldIds) + { + this.rowType = rowType; + this.fieldIds = fieldIds; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IcebergPartitionColumn that = (IcebergPartitionColumn) o; + return Objects.equals(rowType, that.rowType) && Objects.equals(fieldIds, that.fieldIds); + } + + @Override + public int hashCode() + { + return Objects.hash(rowType, fieldIds); + } + } } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java index 2e9b18a7d8b5..f7b37a3f1399 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java @@ -14,6 +14,7 @@ package io.trino.tests.product.iceberg; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Streams; import io.airlift.concurrent.MoreFutures; import io.trino.tempto.ProductTest; @@ -36,6 +37,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; @@ -63,6 +66,7 @@ import static java.util.Arrays.asList; import static java.util.Locale.ENGLISH; import static java.util.concurrent.TimeUnit.SECONDS; +import static java.util.stream.Collectors.toUnmodifiableSet; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -1926,6 +1930,118 @@ public void testUpdateOnPartitionColumn() onSpark().executeQuery("DROP TABLE " + sparkTableName); } + @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}) + public void testHandlingPartitionSchemaEvolutionInPartitionMetadata() + { + String baseTableName = "test_handling_partition_schema_evolution_" + randomTableSuffix(); + String trinoTableName = trinoTableName(baseTableName); + String sparkTableName = sparkTableName(baseTableName); + + onTrino().executeQuery(format("CREATE TABLE %s (old_partition_key INT, new_partition_key INT, value date) WITH (PARTITIONING = array['old_partition_key'])", trinoTableName)); + onTrino().executeQuery(format("INSERT INTO %s VALUES (1, 10, date '2022-04-10'), (2, 20, date '2022-05-11'), (3, 30, date '2022-06-12'), (2, 20, date '2022-06-13')", trinoTableName)); + + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1"), + ImmutableMap.of("old_partition_key", "2"), + ImmutableMap.of("old_partition_key", "3"))); + + onSpark().executeQuery(format("ALTER TABLE %s DROP PARTITION FIELD old_partition_key", sparkTableName)); + onSpark().executeQuery(format("ALTER TABLE %s ADD PARTITION FIELD new_partition_key", sparkTableName)); + + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null"))); + + onTrino().executeQuery(format("INSERT INTO %s VALUES (4, 40, date '2022-08-15')", trinoTableName)); + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "40"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null"))); + + onSpark().executeQuery(format("ALTER TABLE %s DROP PARTITION FIELD new_partition_key", sparkTableName)); + onSpark().executeQuery(format("ALTER TABLE %s ADD PARTITION FIELD old_partition_key", sparkTableName)); + + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "40"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null"))); + + onTrino().executeQuery(format("INSERT INTO %s VALUES (5, 50, date '2022-08-15')", trinoTableName)); + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "40"), + ImmutableMap.of("old_partition_key", "5", "new_partition_key", "null"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null"))); + + onSpark().executeQuery(format("ALTER TABLE %s DROP PARTITION FIELD old_partition_key", sparkTableName)); + onSpark().executeQuery(format("ALTER TABLE %s ADD PARTITION FIELD days(value)", sparkTableName)); + + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null", "value_day", "null"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null", "value_day", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "40", "value_day", "null"), + ImmutableMap.of("old_partition_key", "5", "new_partition_key", "null", "value_day", "null"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null", "value_day", "null"))); + + onTrino().executeQuery(format("INSERT INTO %s VALUES (6, 60, date '2022-08-16')", trinoTableName)); + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null", "value_day", "null"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null", "value_day", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "40", "value_day", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "null", "value_day", "2022-08-16"), + ImmutableMap.of("old_partition_key", "5", "new_partition_key", "null", "value_day", "null"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null", "value_day", "null"))); + + onSpark().executeQuery(format("ALTER TABLE %s DROP PARTITION FIELD value_day", sparkTableName)); + onSpark().executeQuery(format("ALTER TABLE %s ADD PARTITION FIELD months(value)", sparkTableName)); + + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "40", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "null", "value_day", "2022-08-16", "value_month", "null"), + ImmutableMap.of("old_partition_key", "5", "new_partition_key", "null", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null", "value_day", "null", "value_month", "null"))); + + onTrino().executeQuery(format("INSERT INTO %s VALUES (7, 70, date '2022-08-17')", trinoTableName)); + + validatePartitioning(baseTableName, sparkTableName, ImmutableList.of( + ImmutableMap.of("old_partition_key", "1", "new_partition_key", "null", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "null", "value_day", "null", "value_month", "631"), + ImmutableMap.of("old_partition_key", "2", "new_partition_key", "null", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "40", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "null", "new_partition_key", "null", "value_day", "2022-08-16", "value_month", "null"), + ImmutableMap.of("old_partition_key", "5", "new_partition_key", "null", "value_day", "null", "value_month", "null"), + ImmutableMap.of("old_partition_key", "3", "new_partition_key", "null", "value_day", "null", "value_month", "null"))); + } + + private void validatePartitioning(String baseTableName, String sparkTableName, List> expectedValues) + { + List trinoResult = expectedValues.stream().map(m -> + m.entrySet().stream() + .map(entry -> format("%s=%s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining(", ", "{", "}"))) + .collect(toImmutableList()); + List partitioning = onTrino().executeQuery(format("SELECT partition, record_count FROM iceberg.default.\"%s$partitions\"", baseTableName)) + .column(1); + Set partitions = partitioning.stream().map(String::valueOf).collect(toUnmodifiableSet()); + Assertions.assertThat(partitions.size()).isEqualTo(expectedValues.size()); + Assertions.assertThat(partitions).containsAll(trinoResult); + List sparkResult = expectedValues.stream().map(m -> + m.entrySet().stream() + .map(entry -> format("\"%s\":%s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining(",", "{", "}"))) + .collect(toImmutableList()); + partitioning = onSpark().executeQuery(format("SELECT partition from %s.files", sparkTableName)).column(1); + partitions = partitioning.stream().map(String::valueOf).collect(toUnmodifiableSet()); + Assertions.assertThat(partitions.size()).isEqualTo(expectedValues.size()); + Assertions.assertThat(partitions).containsAll(sparkResult); + } + private int calculateMetadataFilesForPartitionedTable(String tableName) { String dataFilePath = onTrino().executeQuery(format("SELECT file_path FROM iceberg.default.\"%s$files\" limit 1", tableName)).row(0).get(0).toString();