diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetColumnIOConverter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetColumnIOConverter.java
index a74d2f9366a..14026716e30 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetColumnIOConverter.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetColumnIOConverter.java
@@ -20,47 +20,52 @@
 import io.trino.parquet.RichColumnDescriptor;
 import io.trino.spi.type.ArrayType;
 import io.trino.spi.type.MapType;
-import io.trino.spi.type.NamedTypeSignature;
 import io.trino.spi.type.RowType;
 import io.trino.spi.type.Type;
-import io.trino.spi.type.TypeSignatureParameter;
 import org.apache.parquet.io.ColumnIO;
 import org.apache.parquet.io.GroupColumnIO;
 import org.apache.parquet.io.PrimitiveColumnIO;
 
-import java.util.List;
-import java.util.Locale;
 import java.util.Optional;
 
 import static io.trino.parquet.ParquetTypeUtils.getArrayElementColumn;
 import static io.trino.parquet.ParquetTypeUtils.getMapKeyValueColumn;
 import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName;
+import static java.util.Locale.ENGLISH;
+import static java.util.Objects.requireNonNull;
 import static org.apache.parquet.io.ColumnIOUtil.columnDefinitionLevel;
 import static org.apache.parquet.io.ColumnIOUtil.columnRepetitionLevel;
 import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;
 
-public final class ParquetColumnIOConverter
+public abstract class ParquetColumnIOConverter<Context>
 {
-    private ParquetColumnIOConverter() {}
+    public static ParquetColumnIOConverter<Type> withLookupColumnByName()
+    {
+        return new ByNameParquetColumnIOConverter();
+    }
 
-    public static Optional<Field> constructField(Type type, ColumnIO columnIO)
+    public Optional<Field> constructField(Context context, Optional<ColumnIO> columnIO)
     {
-        if (columnIO == null) {
+        if (columnIO.isEmpty()) {
             return Optional.empty();
         }
+        return constructField(context, columnIO.get());
+    }
+
+    protected Optional<Field> constructField(Context context, ColumnIO columnIO)
+    {
+        requireNonNull(columnIO, "columnIO is null");
+
         boolean required = columnIO.getType().getRepetition() != OPTIONAL;
         int repetitionLevel = columnRepetitionLevel(columnIO);
         int definitionLevel = columnDefinitionLevel(columnIO);
+        Type type = getType(context);
         if (type instanceof RowType) {
             GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO;
-            List<Type> parameters = type.getTypeParameters();
             ImmutableList.Builder<Optional<Field>> fieldsBuilder = ImmutableList.builder();
-            List<TypeSignatureParameter> fields = type.getTypeSignature().getParameters();
             boolean structHasParameters = false;
-            for (int i = 0; i < fields.size(); i++) {
-                NamedTypeSignature namedTypeSignature = fields.get(i).getNamedTypeSignature();
-                String name = namedTypeSignature.getName().get().toLowerCase(Locale.ENGLISH);
-                Optional<Field> field = constructField(parameters.get(i), lookupColumnByName(groupColumnIO, name));
+            for (int fieldIndex = 0; fieldIndex < type.getTypeParameters().size(); fieldIndex++) {
+                Optional<Field> field = getRowFieldField(context, groupColumnIO, fieldIndex);
                 structHasParameters |= field.isPresent();
                 fieldsBuilder.add(field);
             }
@@ -71,26 +76,87 @@ public static Optional<Field> constructField(Type type, ColumnIO columnIO)
         }
         if (type instanceof MapType) {
             GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO;
-            MapType mapType = (MapType) type;
-            GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO);
-            if (keyValueColumnIO.getChildrenCount() != 2) {
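+            // A map column resolves only when both the key and the value
+            // sub-fields resolve; otherwise the whole map reads as NULL.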
+            Optional<Field> keyField = getMapKeyField(context, groupColumnIO);
+            Optional<Field> valueField = getMapValueField(context, groupColumnIO);
+            if (keyField.isEmpty() || valueField.isEmpty()) {
                 return Optional.empty();
             }
-            Optional<Field> keyField = constructField(mapType.getKeyType(), keyValueColumnIO.getChild(0));
-            Optional<Field> valueField = constructField(mapType.getValueType(), keyValueColumnIO.getChild(1));
             return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(keyField, valueField)));
         }
         if (type instanceof ArrayType) {
             GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO;
-            List<Type> types = type.getTypeParameters();
-            if (groupColumnIO.getChildrenCount() != 1) {
+            Optional<Field> field = getArrayElementField(context, groupColumnIO);
+            if (field.isEmpty()) {
                 return Optional.empty();
             }
-            Optional<Field> field = constructField(types.get(0), getArrayElementColumn(groupColumnIO.getChild(0)));
             return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field)));
         }
         PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO;
         RichColumnDescriptor column = new RichColumnDescriptor(primitiveColumnIO.getColumnDescriptor(), columnIO.getType().asPrimitiveType());
         return Optional.of(new PrimitiveField(type, repetitionLevel, definitionLevel, required, column, primitiveColumnIO.getId()));
     }
+
+    protected abstract Type getType(Context context);
+
+    protected abstract Optional<Field> getArrayElementField(Context arrayContext, GroupColumnIO groupColumnIO);
+
+    protected abstract Optional<Field> getMapKeyField(Context mapContext, GroupColumnIO groupColumnIO);
+
+    protected abstract Optional<Field> getMapValueField(Context mapContext, GroupColumnIO groupColumnIO);
+
+    protected abstract Optional<Field> getRowFieldField(Context rowContext, GroupColumnIO groupColumnIO, int rowFieldIndex);
+
+    private static class ByNameParquetColumnIOConverter
+            extends ParquetColumnIOConverter<Type>
+    {
+        @Override
+        protected Type getType(Type type)
+        {
+            return requireNonNull(type, "type is null");
+        }
+
+        @Override
+        protected Optional<Field> getArrayElementField(Type arrayType, GroupColumnIO groupColumnIO)
+        {
+            if (groupColumnIO.getChildrenCount() != 1) {
+                return Optional.empty();
+            }
+            return constructField(
+                    ((ArrayType) arrayType).getElementType(),
+                    getArrayElementColumn(groupColumnIO.getChild(0)));
+        }
+
+        @Override
+        protected Optional<Field> getMapKeyField(Type mapType, GroupColumnIO groupColumnIO)
+        {
+            GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO);
+            if (keyValueColumnIO.getChildrenCount() != 2) {
+                return Optional.empty();
+            }
+            return constructField(
+                    ((MapType) mapType).getKeyType(),
+                    keyValueColumnIO.getChild(0));
+        }
+
+        @Override
+        protected Optional<Field> getMapValueField(Type mapType, GroupColumnIO groupColumnIO)
+        {
+            GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO);
+            if (keyValueColumnIO.getChildrenCount() != 2) {
+                return Optional.empty();
+            }
+            return constructField(
+                    ((MapType) mapType).getValueType(),
+                    keyValueColumnIO.getChild(1));
+        }
+
+        @Override
+        protected Optional<Field> getRowFieldField(Type rowType, GroupColumnIO groupColumnIO, int rowFieldIndex)
+        {
+            RowType.Field rowField = ((RowType) rowType).getFields().get(rowFieldIndex);
+            String name = rowField.getName().orElseThrow();
+            return Optional.ofNullable(lookupColumnByName(groupColumnIO, name.toLowerCase(ENGLISH)))
+                    .flatMap(columnIO -> constructField(rowField.getType(), columnIO));
+        }
+    }
 }
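The refactoring above turns `constructField` into a template method: the base class keeps the traversal and repetition/definition-level bookkeeping, while subclasses decide how each child of a row, map, or array is located in the Parquet schema (by name here, by Iceberg field ID in the subclass added further below). A minimal caller-side sketch, assuming `fileSchema`, `requestedSchema`, and `trinoType` are already in scope (these names are illustrative, not part of the patch):

```java
// Hypothetical usage sketch: resolve a top-level column by name and build
// the Field tree that the Parquet reader consumes.
MessageColumnIO messageColumnIO = getColumnIO(fileSchema, requestedSchema);
Optional<ColumnIO> columnIO = Optional.ofNullable(lookupColumnByName(messageColumnIO, "a_struct"));
Optional<Field> field = ParquetColumnIOConverter.withLookupColumnByName()
        .constructField(trinoType, columnIO);
// field is empty when the column (or a required map/array sub-field) is absent,
// in which case the reader fills the column with NULLs.
```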
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java
index 982371be4aa..1e37c3bf8f7 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java
@@ -92,7 +92,6 @@
 import static io.trino.plugin.hive.HiveSessionProperties.isParquetIgnoreStatistics;
 import static io.trino.plugin.hive.HiveSessionProperties.isParquetUseColumnIndex;
 import static io.trino.plugin.hive.HiveSessionProperties.isUseParquetColumnNames;
-import static io.trino.plugin.hive.parquet.ParquetColumnIOConverter.constructField;
 import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName;
 import static io.trino.spi.type.BigintType.BIGINT;
 import static java.lang.String.format;
@@ -311,7 +310,9 @@ && predicateMatches(parquetPredicate, block, dataSource, descriptorsByPath, parquetTupleDomain)
                 internalFields.add(Optional.ofNullable(getParquetType(column, fileSchema, useColumnNames))
                         .flatMap(field -> {
                             String columnName = useColumnNames ? column.getBaseColumnName() : fileSchema.getFields().get(column.getBaseHiveColumnIndex()).getName();
-                            return constructField(column.getBaseType(), lookupColumnByName(messageColumn, columnName));
+                            return ParquetColumnIOConverter.withLookupColumnByName().constructField(
+                                    column.getBaseType(),
+                                    Optional.ofNullable(lookupColumnByName(messageColumn, columnName)));
                         }));
             }
         }
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ColumnIdentity.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ColumnIdentity.java
index 9f1a35aa0da..a1fa573695f 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ColumnIdentity.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ColumnIdentity.java
@@ -48,7 +48,7 @@ public ColumnIdentity(
         this.id = id;
         this.name = requireNonNull(name, "name is null");
         this.typeCategory = requireNonNull(typeCategory, "typeCategory is null");
-        this.children = requireNonNull(children, "children is null");
+        this.children = ImmutableList.copyOf(requireNonNull(children, "children is null"));
         checkArgument(
                 children.isEmpty() == (typeCategory == PRIMITIVE),
                 "Children should be empty if and only if column type is primitive");
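The `ColumnIdentity` change above is a defensive copy: the constructor now snapshots `children` into an immutable list, so the invariant enforced by `checkArgument` cannot be broken by a caller mutating the list after construction. A sketch of the failure mode being closed, with invented identifiers:

```java
// Without ImmutableList.copyOf, the PRIMITIVE <-> empty-children invariant
// checked in the constructor could be invalidated after the fact:
List<ColumnIdentity> children = new ArrayList<>(List.of(elementIdentity));
ColumnIdentity struct = new ColumnIdentity(1, "a_struct", TypeCategory.STRUCT, children);
children.clear(); // struct still claims STRUCT but would now report no children
```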
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
index 6bf29cf786c..a18ec64702d 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
@@ -44,6 +44,7 @@
 import io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation;
 import io.trino.plugin.hive.orc.OrcReaderConfig;
 import io.trino.plugin.hive.parquet.HdfsParquetDataSource;
+import io.trino.plugin.hive.parquet.ParquetColumnIOConverter;
 import io.trino.plugin.hive.parquet.ParquetPageSource;
 import io.trino.plugin.hive.parquet.ParquetReaderConfig;
 import io.trino.spi.TrinoException;
@@ -58,6 +59,9 @@
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.TupleDomain;
 import io.trino.spi.security.ConnectorIdentity;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.MapType;
+import io.trino.spi.type.RowType;
 import io.trino.spi.type.StandardTypes;
 import io.trino.spi.type.Type;
 import org.apache.hadoop.conf.Configuration;
@@ -71,6 +75,8 @@
 import org.apache.parquet.hadoop.metadata.BlockMetaData;
 import org.apache.parquet.hadoop.metadata.FileMetaData;
 import org.apache.parquet.hadoop.metadata.ParquetMetadata;
+import org.apache.parquet.io.ColumnIO;
+import org.apache.parquet.io.GroupColumnIO;
 import org.apache.parquet.io.MessageColumnIO;
 import org.apache.parquet.schema.MessageType;
@@ -85,18 +91,22 @@
 import java.util.Optional;
 import java.util.function.Function;
 
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static com.google.common.collect.Iterables.getOnlyElement;
 import static com.google.common.collect.Maps.uniqueIndex;
 import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
 import static io.trino.orc.OrcReader.INITIAL_BATCH_SIZE;
 import static io.trino.orc.OrcReader.ProjectedLayout.fullyProjectedLayout;
+import static io.trino.parquet.ParquetTypeUtils.getArrayElementColumn;
 import static io.trino.parquet.ParquetTypeUtils.getColumnIO;
 import static io.trino.parquet.ParquetTypeUtils.getDescriptors;
+import static io.trino.parquet.ParquetTypeUtils.getMapKeyValueColumn;
 import static io.trino.parquet.ParquetTypeUtils.getParquetTypeByName;
 import static io.trino.parquet.predicate.PredicateUtils.buildPredicate;
 import static io.trino.parquet.predicate.PredicateUtils.predicateMatches;
-import static io.trino.plugin.hive.parquet.ParquetColumnIOConverter.constructField;
 import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_BAD_DATA;
 import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_CANNOT_OPEN_SPLIT;
 import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_CURSOR_ERROR;
@@ -455,10 +465,12 @@ private static ConnectorPageSource createParquetPageSource(
                 .filter(field -> field.getId() != null)
                 .collect(toImmutableMap(field -> field.getId().intValue(), Function.identity()));
 
+        // Map by name for a migrated table
+        boolean mapByName = parquetIdToField.isEmpty();
+
         List<org.apache.parquet.schema.Type> parquetFields = regularColumns.stream()
                 .map(column -> {
-                    if (parquetIdToField.isEmpty()) {
-                        // This is a migrated table
+                    if (mapByName) {
                         return getParquetTypeByName(column.getName(), fileSchema);
                     }
                     return parquetIdToField.get(column.getId());
@@ -504,7 +516,11 @@ private static ConnectorPageSource createParquetPageSource(
                 internalFields.add(Optional.empty());
             }
             else {
-                internalFields.add(constructField(column.getType(), messageColumnIO.getChild(parquetField.getName())));
+                // The top level columns are already mapped by name/id appropriately.
+                Optional<ColumnIO> columnIO = Optional.ofNullable(messageColumnIO.getChild(parquetField.getName()));
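+                // mapByName is true for tables migrated from Hive, whose Parquet
+                // files carry no Iceberg field IDs; otherwise nested fields are
+                // resolved by ID via IcebergParquetColumnIOConverter.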
+                internalFields.add(mapByName
+                        ? ParquetColumnIOConverter.withLookupColumnByName().constructField(trinoType, columnIO)
+                        : IcebergParquetColumnIOConverter.create().constructField(new FieldContext(trinoType, column.getColumnIdentity()), columnIO));
             }
         }
@@ -564,4 +580,119 @@ private static TrinoException handleException(OrcDataSourceId dataSourceId, Exception exception)
         }
         return new TrinoException(ICEBERG_CURSOR_ERROR, format("Failed to read ORC file: %s", dataSourceId), exception);
     }
+
+    private static class IcebergParquetColumnIOConverter
+            extends ParquetColumnIOConverter<FieldContext>
+    {
+        static IcebergParquetColumnIOConverter create()
+        {
+            return new IcebergParquetColumnIOConverter();
+        }
+
+        @Override
+        protected Type getType(FieldContext context)
+        {
+            return context.getType();
+        }
+
+        @Override
+        protected Optional<Field> getArrayElementField(FieldContext arrayContext, GroupColumnIO groupColumnIO)
+        {
+            checkArgument(arrayContext.getColumnIdentity().getChildren().size() == 1, "Not an array: %s", arrayContext);
+
+            if (groupColumnIO.getChildrenCount() != 1) {
+                return Optional.empty();
+            }
+            return constructField(
+                    new FieldContext(
+                            ((ArrayType) arrayContext.getType()).getElementType(),
+                            getOnlyElement(arrayContext.getColumnIdentity().getChildren())),
+                    // TODO validate column ID
+                    getArrayElementColumn(groupColumnIO.getChild(0)));
+        }
+
+        @Override
+        protected Optional<Field> getMapKeyField(FieldContext mapContext, GroupColumnIO groupColumnIO)
+        {
+            checkArgument(mapContext.getColumnIdentity().getChildren().size() == 2, "Not a map: %s", mapContext);
+
+            GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO);
+            if (keyValueColumnIO.getChildrenCount() != 2) {
+                return Optional.empty();
+            }
+            return constructField(
+                    new FieldContext(
+                            ((MapType) mapContext.getType()).getKeyType(),
+                            mapContext.getColumnIdentity().getChildren().get(0)),
+                    // TODO validate column ID
+                    keyValueColumnIO.getChild(0));
+        }
+
+        @Override
+        protected Optional<Field> getMapValueField(FieldContext mapContext, GroupColumnIO groupColumnIO)
+        {
+            checkArgument(mapContext.getColumnIdentity().getChildren().size() == 2, "Not a map: %s", mapContext);
+
+            GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO);
+            if (keyValueColumnIO.getChildrenCount() != 2) {
+                return Optional.empty();
+            }
+            return constructField(
+                    new FieldContext(
+                            ((MapType) mapContext.getType()).getValueType(),
+                            mapContext.getColumnIdentity().getChildren().get(1)),
+                    // TODO validate column ID
+                    keyValueColumnIO.getChild(1));
+        }
+
+        @Override
+        protected Optional<Field> getRowFieldField(FieldContext rowContext, GroupColumnIO groupColumnIO, int rowFieldIndex)
+        {
+            checkArgument(rowFieldIndex < rowContext.getColumnIdentity().getChildren().size(), "Row field out of bounds, or not a row: %s, %s", rowFieldIndex, rowContext);
+
+            FieldContext rowFieldContext = new FieldContext(
+                    ((RowType) rowContext.getType()).getFields().get(rowFieldIndex).getType(),
+                    rowContext.getColumnIdentity().getChildren().get(rowFieldIndex));
+
+            // Match on the Parquet field ID, not the field name, so that renames
+            // in the Iceberg schema still resolve to the originally written data.
+            int fieldId = rowFieldContext.getColumnIdentity().getId();
+            for (int i = 0; i < groupColumnIO.getChildrenCount(); i++) {
+                ColumnIO child = groupColumnIO.getChild(i);
+                if (child.getType().getId().intValue() == fieldId) {
+                    return constructField(rowFieldContext, child);
+                }
+            }
+            return Optional.empty();
+        }
+    }
+
+    private static class FieldContext
+    {
+        private final Type type;
+        private final ColumnIdentity columnIdentity;
+
+        public FieldContext(Type type, ColumnIdentity columnIdentity)
+        {
+            this.type = requireNonNull(type, "type is null");
+            this.columnIdentity = requireNonNull(columnIdentity, "columnIdentity is null");
+        }
+
+        public Type getType()
+        {
+            return type;
+        }
+
+        public ColumnIdentity getColumnIdentity()
+        {
+            return columnIdentity;
+        }
+
+        @Override
+        public String toString()
+        {
+            return toStringHelper(this)
+                    .add("type", type)
+                    .add("columnIdentity", columnIdentity)
+                    .toString();
+        }
+    }
 }
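The behavioral core of the Iceberg subclass is `getRowFieldField`: struct members are matched on the Parquet field ID rather than the (possibly renamed) name, which is what makes the renamed-field reads in the test below work for Parquet. A condensed sketch of that lookup as a hypothetical standalone helper, with a null-ID guard added here as a defensive assumption for files whose nested fields carry no IDs (the patch itself dereferences `getId()` unconditionally):

```java
// Hypothetical helper mirroring getRowFieldField: locate a struct member by
// Iceberg field ID. A dropped-and-re-added column gets a fresh ID, so it
// matches nothing in the file and reads as NULL.
private static Optional<ColumnIO> lookupColumnById(GroupColumnIO groupColumnIO, int fieldId)
{
    for (int i = 0; i < groupColumnIO.getChildrenCount(); i++) {
        ColumnIO child = groupColumnIO.getChild(i);
        org.apache.parquet.schema.Type.ID id = child.getType().getId();
        if (id != null && id.intValue() == fieldId) {
            return Optional.of(child);
        }
    }
    return Optional.empty();
}
```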
diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
index ec454d74edd..b935d57fa7c 100644
--- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
+++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
@@ -562,48 +562,38 @@ public void testIdBasedFieldMapping(StorageFormat storageFormat)
                 row("a_struct", "row(renamed bigint, keep bigint, CaseSensitive bigint, drop_and_add bigint, added bigint)"),
                 row("a_partition", "bigint"));
 
-        if (storageFormat == StorageFormat.PARQUET) {
-            // TODO (https://github.com/trinodb/trino/issues/8750) the results should be the same for all storage formats
-
-            // TODO support Row (JAVA_OBJECT) in Tempto and switch to QueryAssert
-            Assertions.assertThat(onTrino().executeQuery(format("SELECT quite_renamed_col, keep_col, drop_and_add_col, add_col, casesensitivecol, a_struct, a_partition FROM %s", trinoTableName)).rows())
-                    .containsOnly(asList(
-                            2L, // quite_renamed_col
-                            3L, // keep_col
-                            null, // drop_and_add_col; dropping and re-adding changes id
-                            null, // add_col
-                            5L, // CaseSensitiveCol
-                            rowBuilder()
-                                    // Rename does not change id
-                                    .addField("renamed", null)
-                                    .addField("keep", 12L)
-                                    .addField("CaseSensitive", 14L)
-                                    // Dropping and re-adding changes id, so TODO it should be null
-                                    .addField("drop_and_add", 13L)
-                                    .addField("added", null)
-                                    .build(),
-                            1001L));
-        }
-        else {
-            // TODO support Row (JAVA_OBJECT) in Tempto and switch to QueryAssert
-            Assertions.assertThat(onTrino().executeQuery(format("SELECT quite_renamed_col, keep_col, drop_and_add_col, add_col, casesensitivecol, a_struct, a_partition FROM %s", trinoTableName)).rows())
-                    .containsOnly(asList(
-                            2L, // quite_renamed_col
-                            3L, // keep_col
-                            null, // drop_and_add_col; dropping and re-adding changes id
-                            null, // add_col
-                            5L, // CaseSensitiveCol
-                            rowBuilder()
-                                    // Rename does not change id
-                                    .addField("renamed", 11L)
-                                    .addField("keep", 12L)
-                                    .addField("CaseSensitive", 14L)
-                                    // Dropping and re-adding changes id
-                                    .addField("drop_and_add", null)
-                                    .addField("added", null)
-                                    .build(),
-                            1001L));
-        }
+        // TODO support Row (JAVA_OBJECT) in Tempto and switch to QueryAssert
+        Assertions.assertThat(onTrino().executeQuery(format("SELECT quite_renamed_col, keep_col, drop_and_add_col, add_col, casesensitivecol, a_struct, a_partition FROM %s", trinoTableName)).rows())
+                .containsOnly(asList(
+                        2L, // quite_renamed_col
+                        3L, // keep_col
+                        null, // drop_and_add_col; dropping and re-adding changes id
+                        null, // add_col
+                        5L, // CaseSensitiveCol
+                        rowBuilder()
+                                // Rename does not change id
+                                .addField("renamed", 11L)
+                                .addField("keep", 12L)
+                                .addField("CaseSensitive", 14L)
+                                // Dropping and re-adding changes id
+                                .addField("drop_and_add", null)
+                                .addField("added", null)
+                                .build(),
+                        1001L));
+
+        // smoke test for dereference
+        assertThat(onTrino().executeQuery(format("SELECT a_struct.renamed FROM %s", trinoTableName))).containsOnly(row(11L));
+        assertThat(onTrino().executeQuery(format("SELECT a_struct.keep FROM %s", trinoTableName))).containsOnly(row(12L));
+        assertThat(onTrino().executeQuery(format("SELECT a_struct.casesensitive FROM %s", trinoTableName))).containsOnly(row(14L));
+        assertThat(onTrino().executeQuery(format("SELECT a_struct.drop_and_add FROM %s", trinoTableName))).containsOnly(row((Object) null));
+        assertThat(onTrino().executeQuery(format("SELECT a_struct.added FROM %s", trinoTableName))).containsOnly(row((Object) null));
+
+        // smoke test for dereference in a predicate
+        assertThat(onTrino().executeQuery(format("SELECT keep_col FROM %s WHERE a_struct.renamed = 11", trinoTableName))).containsOnly(row(3L));
+        assertThat(onTrino().executeQuery(format("SELECT keep_col FROM %s WHERE a_struct.keep = 12", trinoTableName))).containsOnly(row(3L));
+        assertThat(onTrino().executeQuery(format("SELECT keep_col FROM %s WHERE a_struct.casesensitive = 14", trinoTableName))).containsOnly(row(3L));
+        assertThat(onTrino().executeQuery(format("SELECT keep_col FROM %s WHERE a_struct.drop_and_add IS NULL", trinoTableName))).containsOnly(row(3L));
+        assertThat(onTrino().executeQuery(format("SELECT keep_col FROM %s WHERE a_struct.added IS NULL", trinoTableName))).containsOnly(row(3L));
 
         onSpark().executeQuery("DROP TABLE " + sparkTableName);
     }