diff --git a/docs/src/main/sphinx/connector/hive.rst b/docs/src/main/sphinx/connector/hive.rst index 439d6fd54c09..a1c6a35b2170 100644 --- a/docs/src/main/sphinx/connector/hive.rst +++ b/docs/src/main/sphinx/connector/hive.rst @@ -484,6 +484,12 @@ with Parquet files performed by the Hive connector. definition. The equivalent catalog session property is ``parquet_use_column_names``. - ``true`` + * - ``parquet.optimized-writer.validation-percentage`` + - Percentage of parquet files to validate after write by re-reading the whole file + when ``parquet.experimental-optimized-writer.enabled`` is set to ``true``. + The equivalent catalog session property is ``parquet_optimized_writer_validation_percentage``. + Validation can be turned off by setting this property to ``0``. + - ``5`` Metastore configuration properties diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ColumnStatisticsValidation.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ColumnStatisticsValidation.java new file mode 100644 index 000000000000..71e29b6c1134 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ColumnStatisticsValidation.java @@ -0,0 +1,177 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.ColumnarArray; +import io.trino.spi.block.ColumnarMap; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.block.ColumnarArray.toColumnarArray; +import static io.trino.spi.block.ColumnarMap.toColumnarMap; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +class ColumnStatisticsValidation +{ + private final Type type; + private final List fieldBuilders; + + private long valuesCount; + private long nonLeafValuesCount; + + public ColumnStatisticsValidation(Type type) + { + this.type = requireNonNull(type, "type is null"); + this.fieldBuilders = type.getTypeParameters().stream() + .map(ColumnStatisticsValidation::new) + .collect(toImmutableList()); + } + + public void addBlock(Block block) + { + addBlock(block, new ColumnStatistics(0, 0)); + } + + public List build() + { + if (fieldBuilders.isEmpty()) { + return ImmutableList.of(new ColumnStatistics(valuesCount, nonLeafValuesCount)); + } + return fieldBuilders.stream() + .flatMap(builder -> builder.build().stream()) + .collect(toImmutableList()); + } + + private void addBlock(Block block, ColumnStatistics columnStatistics) + { + if (fieldBuilders.isEmpty()) { + addPrimitiveBlock(block); + valuesCount += columnStatistics.valuesCount(); + nonLeafValuesCount += columnStatistics.nonLeafValuesCount(); + return; + } 
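+ // Nested types: decompose the block into its child field blocks and fold this level's
+ // null/empty counts into the statistics passed down, so that the leaf builders also
+ // account for values that are missing at an ancestor level.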
+ + List fields; + ColumnStatistics mergedColumnStatistics; + if (type instanceof ArrayType) { + ColumnarArray columnarArray = toColumnarArray(block); + fields = ImmutableList.of(columnarArray.getElementsBlock()); + mergedColumnStatistics = columnStatistics.merge(addArrayBlock(columnarArray)); + } + else if (type instanceof MapType) { + ColumnarMap columnarMap = toColumnarMap(block); + fields = ImmutableList.of(columnarMap.getKeysBlock(), columnarMap.getValuesBlock()); + mergedColumnStatistics = columnStatistics.merge(addMapBlock(columnarMap)); + } + else if (type instanceof RowType) { + ColumnarRow columnarRow = ColumnarRow.toColumnarRow(block); + ImmutableList.Builder fieldsBuilder = ImmutableList.builder(); + for (int index = 0; index < columnarRow.getFieldCount(); index++) { + fieldsBuilder.add(columnarRow.getField(index)); + } + fields = fieldsBuilder.build(); + mergedColumnStatistics = columnStatistics.merge(addRowBlock(columnarRow)); + } + else { + throw new TrinoException(NOT_SUPPORTED, format("Unsupported type: %s", type)); + } + + for (int i = 0; i < fieldBuilders.size(); i++) { + fieldBuilders.get(i).addBlock(fields.get(i), mergedColumnStatistics); + } + } + + private void addPrimitiveBlock(Block block) + { + valuesCount += block.getPositionCount(); + if (!block.mayHaveNull()) { + return; + } + int nullsCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nullsCount += block.isNull(position) ? 1 : 0; + } + nonLeafValuesCount += nullsCount; + } + + private static ColumnStatistics addMapBlock(ColumnarMap block) + { + if (!block.mayHaveNull()) { + int emptyEntriesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + emptyEntriesCount += block.getEntryCount(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(emptyEntriesCount, emptyEntriesCount); + } + int nonLeafValuesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nonLeafValuesCount += block.isNull(position) || block.getEntryCount(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(nonLeafValuesCount, nonLeafValuesCount); + } + + private static ColumnStatistics addArrayBlock(ColumnarArray block) + { + if (!block.mayHaveNull()) { + int emptyEntriesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + emptyEntriesCount += block.getLength(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(emptyEntriesCount, emptyEntriesCount); + } + int nonLeafValuesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nonLeafValuesCount += block.isNull(position) || block.getLength(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(nonLeafValuesCount, nonLeafValuesCount); + } + + private static ColumnStatistics addRowBlock(ColumnarRow block) + { + if (!block.mayHaveNull()) { + return new ColumnStatistics(0, 0); + } + int nullsCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nullsCount += block.isNull(position) ? 1 : 0; + } + return new ColumnStatistics(nullsCount, nullsCount); + } + + /** + * @param valuesCount Count of values for a column field, including nulls, empty and defined values. 
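+ * This corresponds to the value count recorded in the column chunk metadata of the written file.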
+ * @param nonLeafValuesCount Count of non-leaf values for a column field, this is nulls count for primitives + * and count of values below the max definition level for nested types + */ + record ColumnStatistics(long valuesCount, long nonLeafValuesCount) + { + ColumnStatistics merge(ColumnStatistics other) + { + return new ColumnStatistics( + valuesCount + other.valuesCount(), + nonLeafValuesCount + other.nonLeafValuesCount()); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java index 3f3498177c30..719190d336ea 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java @@ -34,4 +34,14 @@ public ParquetCorruptionException(Throwable cause, String messageFormat, Object. { super(format(messageFormat, args), cause); } + + public ParquetCorruptionException(ParquetDataSourceId dataSourceId, String messageFormat, Object... args) + { + super(formatMessage(dataSourceId, messageFormat, args)); + } + + private static String formatMessage(ParquetDataSourceId dataSourceId, String messageFormat, Object[] args) + { + return "Malformed Parquet file. " + format(messageFormat, args) + " [" + dataSourceId + "]"; + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java index 26fd692d1c73..7f82fd3f2be0 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java @@ -13,7 +13,12 @@ */ package io.trino.parquet; +import com.google.common.collect.ImmutableList; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.DecimalType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; import org.apache.parquet.io.ColumnIO; @@ -31,10 +36,14 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static org.apache.parquet.io.ColumnIOUtil.columnDefinitionLevel; +import static org.apache.parquet.io.ColumnIOUtil.columnRepetitionLevel; +import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; import static org.apache.parquet.schema.Type.Repetition.REPEATED; public final class ParquetTypeUtils @@ -259,4 +268,50 @@ public static long getShortDecimalValue(byte[] bytes, int startOffset, int lengt return value; } + + public static Optional constructField(Type type, ColumnIO columnIO) + { + if (columnIO == null) { + return Optional.empty(); + } + boolean required = columnIO.getType().getRepetition() != OPTIONAL; + int repetitionLevel = columnRepetitionLevel(columnIO); + int definitionLevel = columnDefinitionLevel(columnIO); + if (type instanceof RowType rowType) { + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + ImmutableList.Builder> fieldsBuilder = ImmutableList.builder(); + List fields = rowType.getFields(); + boolean structHasParameters = false; + for (RowType.Field rowField : fields) { + String name = rowField.getName().orElseThrow().toLowerCase(Locale.ENGLISH); + Optional field = constructField(rowField.getType(), 
lookupColumnByName(groupColumnIO, name)); + structHasParameters |= field.isPresent(); + fieldsBuilder.add(field); + } + if (structHasParameters) { + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, fieldsBuilder.build())); + } + return Optional.empty(); + } + if (type instanceof MapType mapType) { + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO); + if (keyValueColumnIO.getChildrenCount() != 2) { + return Optional.empty(); + } + Optional keyField = constructField(mapType.getKeyType(), keyValueColumnIO.getChild(0)); + Optional valueField = constructField(mapType.getValueType(), keyValueColumnIO.getChild(1)); + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(keyField, valueField))); + } + if (type instanceof ArrayType arrayType) { + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + if (groupColumnIO.getChildrenCount() != 1) { + return Optional.empty(); + } + Optional field = constructField(arrayType.getElementType(), getArrayElementColumn(groupColumnIO.getChild(0))); + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field))); + } + PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO; + return Optional.of(new PrimitiveField(type, required, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId())); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java index ee56f75aba43..c194405c0029 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java @@ -26,4 +26,12 @@ public static void validateParquet(boolean condition, String formatString, Objec throw new ParquetCorruptionException(format(formatString, args)); } } + + public static void validateParquet(boolean condition, ParquetDataSourceId dataSourceId, String formatString, Object... args) + throws ParquetCorruptionException + { + if (!condition) { + throw new ParquetCorruptionException(dataSourceId, formatString, args); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetWriteValidation.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetWriteValidation.java new file mode 100644 index 000000000000..65eeab7ee484 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetWriteValidation.java @@ -0,0 +1,664 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet; + +import com.google.common.collect.ImmutableList; +import io.airlift.slice.SizeOf; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.airlift.slice.XxHash64; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.type.Type; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.format.ColumnChunk; +import org.apache.parquet.format.ColumnMetaData; +import org.apache.parquet.format.RowGroup; +import org.apache.parquet.format.converter.ParquetMetadataConverter; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.parquet.internal.hadoop.metadata.IndexReference; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.openjdk.jol.info.ClassLayout; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.slice.SizeOf.SIZE_OF_INT; +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.trino.parquet.ColumnStatisticsValidation.ColumnStatistics; +import static io.trino.parquet.ParquetValidationUtils.validateParquet; +import static io.trino.parquet.ParquetWriteValidation.IndexReferenceValidation.fromIndexReference; +import static java.util.Objects.requireNonNull; + +public class ParquetWriteValidation +{ + private static final ParquetMetadataConverter METADATA_CONVERTER = new ParquetMetadataConverter(); + + private final String createdBy; + private final Optional timeZoneId; + private final List columns; + private final List rowGroups; + private final WriteChecksum checksum; + private final List types; + private final List columnNames; + + private ParquetWriteValidation( + String createdBy, + Optional timeZoneId, + List columns, + List rowGroups, + WriteChecksum checksum, + List types, + List columnNames) + { + this.createdBy = requireNonNull(createdBy, "createdBy is null"); + checkArgument(!createdBy.isEmpty(), "createdBy is empty"); + this.timeZoneId = requireNonNull(timeZoneId, "timeZoneId is null"); + this.columns = requireNonNull(columns, "columnPaths is null"); + this.rowGroups = requireNonNull(rowGroups, "rowGroups is null"); + this.checksum = requireNonNull(checksum, "checksum is null"); + this.types = requireNonNull(types, "types is null"); + this.columnNames = requireNonNull(columnNames, "columnNames is null"); + } + + public String getCreatedBy() + { + return createdBy; + } + + public List getTypes() + { + return types; + } + + public List getColumnNames() + { + return columnNames; + } + + public void validateTimeZone(ParquetDataSourceId dataSourceId, Optional actualTimeZoneId) + throws ParquetCorruptionException + { + validateParquet( + timeZoneId.equals(actualTimeZoneId), + dataSourceId, + "Found unexpected time zone %s, expected %s", + actualTimeZoneId, + timeZoneId); + } + + public void validateColumns(ParquetDataSourceId dataSourceId, MessageType schema) + throws 
ParquetCorruptionException + { + List actualColumns = schema.getColumns(); + validateParquet( + actualColumns.size() == columns.size(), + dataSourceId, + "Found columns %s, expected %s", + actualColumns, + columns); + for (int columnIndex = 0; columnIndex < columns.size(); columnIndex++) { + validateColumnDescriptorsSame(actualColumns.get(columnIndex), columns.get(columnIndex), dataSourceId); + } + } + + public void validateBlocksMetadata(ParquetDataSourceId dataSourceId, List blocksMetaData) + throws ParquetCorruptionException + { + validateParquet( + blocksMetaData.size() == rowGroups.size(), + dataSourceId, + "Number of row groups %d did not match %d", + blocksMetaData.size(), + rowGroups.size()); + for (int rowGroupIndex = 0; rowGroupIndex < blocksMetaData.size(); rowGroupIndex++) { + BlockMetaData block = blocksMetaData.get(rowGroupIndex); + RowGroup rowGroup = rowGroups.get(rowGroupIndex); + validateParquet( + block.getRowCount() == rowGroup.getNum_rows(), + dataSourceId, + "Number of rows %d in row group %d did not match %d", + block.getRowCount(), + rowGroupIndex, + rowGroup.getNum_rows()); + + List columnChunkMetaData = block.getColumns(); + validateParquet( + columnChunkMetaData.size() == rowGroup.getColumnsSize(), + dataSourceId, + "Number of columns %d in row group %d did not match %d", + columnChunkMetaData.size(), + rowGroupIndex, + rowGroup.getColumnsSize()); + + for (int columnIndex = 0; columnIndex < columnChunkMetaData.size(); columnIndex++) { + ColumnChunkMetaData actualColumnMetadata = columnChunkMetaData.get(columnIndex); + ColumnChunk columnChunk = rowGroup.getColumns().get(columnIndex); + ColumnMetaData expectedColumnMetadata = columnChunk.getMeta_data(); + verifyColumnMetadataMatch( + actualColumnMetadata.getCodec().getParquetCompressionCodec().equals(expectedColumnMetadata.getCodec()), + "Compression codec", + actualColumnMetadata.getCodec(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getCodec()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName().equals(METADATA_CONVERTER.getPrimitive(expectedColumnMetadata.getType())), + "Type", + actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getType()); + + verifyColumnMetadataMatch( + areEncodingsSame(actualColumnMetadata.getEncodings(), expectedColumnMetadata.getEncodings()), + "Encodings", + actualColumnMetadata.getEncodings(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getEncodings()); + + verifyColumnMetadataMatch( + areStatisticsSame(actualColumnMetadata.getStatistics(), expectedColumnMetadata.getStatistics()), + "Statistics", + actualColumnMetadata.getStatistics(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getStatistics()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getFirstDataPageOffset() == expectedColumnMetadata.getData_page_offset(), + "Data page offset", + actualColumnMetadata.getFirstDataPageOffset(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getData_page_offset()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getDictionaryPageOffset() == expectedColumnMetadata.getDictionary_page_offset(), + "Dictionary page offset", + actualColumnMetadata.getDictionaryPageOffset(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + 
expectedColumnMetadata.getDictionary_page_offset()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getValueCount() == expectedColumnMetadata.getNum_values(), + "Value count", + actualColumnMetadata.getValueCount(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getNum_values()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getTotalUncompressedSize() == expectedColumnMetadata.getTotal_uncompressed_size(), + "Total uncompressed size", + actualColumnMetadata.getTotalUncompressedSize(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getTotal_uncompressed_size()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getTotalSize() == expectedColumnMetadata.getTotal_compressed_size(), + "Total size", + actualColumnMetadata.getTotalSize(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getTotal_compressed_size()); + + IndexReferenceValidation expectedColumnIndexReference = new IndexReferenceValidation(columnChunk.getColumn_index_offset(), columnChunk.getColumn_index_length()); + IndexReference actualColumnIndexReference = actualColumnMetadata.getColumnIndexReference(); + verifyColumnMetadataMatch( + actualColumnIndexReference == null || fromIndexReference(actualColumnMetadata.getColumnIndexReference()).equals(expectedColumnIndexReference), + "Column index reference", + actualColumnIndexReference, + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnIndexReference); + + IndexReferenceValidation expectedOffsetIndexReference = new IndexReferenceValidation(columnChunk.getOffset_index_offset(), columnChunk.getOffset_index_length()); + IndexReference actualOffsetIndexReference = actualColumnMetadata.getOffsetIndexReference(); + verifyColumnMetadataMatch( + actualOffsetIndexReference == null || fromIndexReference(actualOffsetIndexReference).equals(expectedOffsetIndexReference), + "Offset index reference", + actualOffsetIndexReference, + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedOffsetIndexReference); + } + } + } + + public void validateChecksum(ParquetDataSourceId dataSourceId, WriteChecksum actualChecksum) + throws ParquetCorruptionException + { + validateParquet( + checksum.totalRowCount() == actualChecksum.totalRowCount(), + dataSourceId, + "Write validation failed: Expected row count %d, found %d", + checksum.totalRowCount(), + actualChecksum.totalRowCount()); + + List columnHashes = actualChecksum.columnHashes(); + for (int columnIndex = 0; columnIndex < columnHashes.size(); columnIndex++) { + long expectedHash = checksum.columnHashes().get(columnIndex); + validateParquet( + expectedHash == columnHashes.get(columnIndex), + dataSourceId, + "Invalid checksum for column %s: Expected hash %d, found %d", + columnIndex, + expectedHash, + columnHashes.get(columnIndex)); + } + } + + public record WriteChecksum(long totalRowCount, List columnHashes) + { + public WriteChecksum(long totalRowCount, List columnHashes) + { + this.totalRowCount = totalRowCount; + this.columnHashes = ImmutableList.copyOf(requireNonNull(columnHashes, "columnHashes is null")); + } + } + + public static class WriteChecksumBuilder + { + private final List validationHashes; + private final List columnHashes; + private final byte[] longBuffer = new byte[Long.BYTES]; + private final Slice longSlice = Slices.wrappedBuffer(longBuffer); + + private long totalRowCount; + + private WriteChecksumBuilder(List types) + { + 
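+ // one type-specific hash function and one running xxHash64 accumulator per column;
+ // values are hashed position by position as pages are added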
this.validationHashes = requireNonNull(types, "types is null").stream() + .map(ValidationHash::createValidationHash) + .collect(toImmutableList()); + + ImmutableList.Builder columnHashes = ImmutableList.builder(); + for (Type ignored : types) { + columnHashes.add(new XxHash64()); + } + this.columnHashes = columnHashes.build(); + } + + public static WriteChecksumBuilder createWriteChecksumBuilder(List readTypes) + { + return new WriteChecksumBuilder(readTypes); + } + + public void addPage(Page page) + { + requireNonNull(page, "page is null"); + checkArgument( + page.getChannelCount() == columnHashes.size(), + "Invalid page: page channels count %s did not match columns count %s", + page.getChannelCount(), + columnHashes.size()); + + for (int channel = 0; channel < columnHashes.size(); channel++) { + ValidationHash validationHash = validationHashes.get(channel); + Block block = page.getBlock(channel); + XxHash64 xxHash64 = columnHashes.get(channel); + for (int position = 0; position < block.getPositionCount(); position++) { + long hash = validationHash.hash(block, position); + longSlice.setLong(0, hash); + xxHash64.update(longBuffer); + } + } + totalRowCount += page.getPositionCount(); + } + + public WriteChecksum build() + { + return new WriteChecksum( + totalRowCount, + columnHashes.stream() + .map(XxHash64::hash) + .collect(toImmutableList())); + } + } + + public void validateRowGroupStatistics(ParquetDataSourceId dataSourceId, BlockMetaData blockMetaData, List actualColumnStatistics) + throws ParquetCorruptionException + { + List columnChunks = blockMetaData.getColumns(); + checkArgument( + columnChunks.size() == actualColumnStatistics.size(), + "Column chunk metadata count %s did not match column fields count %s", + columnChunks.size(), + actualColumnStatistics.size()); + + for (int columnIndex = 0; columnIndex < columnChunks.size(); columnIndex++) { + ColumnChunkMetaData columnMetaData = columnChunks.get(columnIndex); + ColumnStatistics columnStatistics = actualColumnStatistics.get(columnIndex); + long expectedValuesCount = columnMetaData.getValueCount(); + validateParquet( + expectedValuesCount == columnStatistics.valuesCount(), + dataSourceId, + "Invalid values count for column %s: Expected %d, found %d", + columnIndex, + expectedValuesCount, + columnStatistics.valuesCount()); + + Statistics parquetStatistics = columnMetaData.getStatistics(); + if (parquetStatistics.isNumNullsSet()) { + long expectedNullsCount = parquetStatistics.getNumNulls(); + validateParquet( + expectedNullsCount == columnStatistics.nonLeafValuesCount(), + dataSourceId, + "Invalid nulls count for column %s: Expected %d, found %d", + columnIndex, + expectedNullsCount, + columnStatistics.nonLeafValuesCount()); + } + } + } + + public static class StatisticsValidation + { + private final List types; + private List columnStatisticsValidations; + + private StatisticsValidation(List types) + { + this.types = requireNonNull(types, "types is null"); + this.columnStatisticsValidations = types.stream() + .map(ColumnStatisticsValidation::new) + .collect(toImmutableList()); + } + + public static StatisticsValidation createStatisticsValidationBuilder(List readTypes) + { + return new StatisticsValidation(readTypes); + } + + public void addPage(Page page) + { + requireNonNull(page, "page is null"); + checkArgument( + page.getChannelCount() == columnStatisticsValidations.size(), + "Invalid page: page channels count %s did not match columns count %s", + page.getChannelCount(), + columnStatisticsValidations.size()); + + for (int 
channel = 0; channel < columnStatisticsValidations.size(); channel++) { + ColumnStatisticsValidation columnStatisticsValidation = columnStatisticsValidations.get(channel); + columnStatisticsValidation.addBlock(page.getBlock(channel)); + } + } + + public void reset() + { + this.columnStatisticsValidations = types.stream() + .map(ColumnStatisticsValidation::new) + .collect(toImmutableList()); + } + + public List build() + { + return this.columnStatisticsValidations.stream() + .flatMap(validation -> validation.build().stream()) + .collect(toImmutableList()); + } + } + + public static class ParquetWriteValidationBuilder + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ParquetWriteValidationBuilder.class).instanceSize(); + private static final int COLUMN_DESCRIPTOR_INSTANCE_SIZE = ClassLayout.parseClass(ColumnDescriptor.class).instanceSize(); + private static final int PRIMITIVE_TYPE_INSTANCE_SIZE = ClassLayout.parseClass(PrimitiveType.class).instanceSize(); + + private final List types; + private final List columnNames; + private final WriteChecksumBuilder checksum; + + private String createdBy; + private Optional timeZoneId = Optional.empty(); + private List columns; + private List rowGroups; + private long retainedSize = INSTANCE_SIZE; + + public ParquetWriteValidationBuilder(List types, List columnNames) + { + this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); + this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null")); + checkArgument( + types.size() == columnNames.size(), + "Types count %s did not match column names count %s", + types.size(), + columnNames.size()); + this.checksum = new WriteChecksumBuilder(types); + retainedSize += estimatedSizeOf(types, type -> 0) + + estimatedSizeOf(columnNames, SizeOf::estimatedSizeOf); + } + + public long getRetainedSize() + { + return retainedSize; + } + + public void setCreatedBy(String createdBy) + { + this.createdBy = createdBy; + retainedSize += estimatedSizeOf(createdBy); + } + + public void setTimeZone(Optional timeZoneId) + { + this.timeZoneId = timeZoneId; + timeZoneId.ifPresent(id -> retainedSize += estimatedSizeOf(id)); + } + + public void setColumns(List columns) + { + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + retainedSize += estimatedSizeOf(columns, descriptor -> { + return COLUMN_DESCRIPTOR_INSTANCE_SIZE + + (2 * SIZE_OF_INT) // maxRep, maxDef + + estimatedSizeOfStringArray(descriptor.getPath()) + + PRIMITIVE_TYPE_INSTANCE_SIZE + + (3 * SIZE_OF_INT); // primitive, length, columnOrder + }); + } + + public void setRowGroups(List rowGroups) + { + this.rowGroups = ImmutableList.copyOf(requireNonNull(rowGroups, "rowGroups is null")); + } + + public void addPage(Page page) + { + checksum.addPage(page); + } + + public ParquetWriteValidation build() + { + return new ParquetWriteValidation( + createdBy, + timeZoneId, + columns, + rowGroups, + checksum.build(), + types, + columnNames); + } + } + + // parquet-mr IndexReference class lacks equals and toString implementations + static class IndexReferenceValidation + { + private final long offset; + private final int length; + + private IndexReferenceValidation(long offset, int length) + { + this.offset = offset; + this.length = length; + } + + static IndexReferenceValidation fromIndexReference(IndexReference indexReference) + { + return new IndexReferenceValidation(indexReference.getOffset(), indexReference.getLength()); + } + + @Override + public boolean equals(Object o) + { + if 
(this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IndexReferenceValidation that = (IndexReferenceValidation) o; + return offset == that.offset && length == that.length; + } + + @Override + public int hashCode() + { + return Objects.hash(offset, length); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("offset", offset) + .add("length", length) + .toString(); + } + } + + private static void verifyColumnMetadataMatch( + boolean condition, + String name, + T actual, + ColumnPath path, + int rowGroup, + ParquetDataSourceId dataSourceId, + U expected) + throws ParquetCorruptionException + { + if (!condition) { + throw new ParquetCorruptionException( + dataSourceId, + "%s [%s] for column %s in row group %d did not match [%s]", + name, + actual, + path, + rowGroup, + expected); + } + } + + private static boolean areEncodingsSame(Set actual, List expected) + { + return actual.equals(expected.stream().map(METADATA_CONVERTER::getEncoding).collect(toImmutableSet())); + } + + private static boolean areStatisticsSame(org.apache.parquet.column.statistics.Statistics actual, org.apache.parquet.format.Statistics expected) + { + Statistics.Builder expectedStatsBuilder = Statistics.getBuilderForReading(actual.type()); + if (expected.isSetNull_count()) { + expectedStatsBuilder.withNumNulls(expected.getNull_count()); + } + if (expected.isSetMin_value()) { + expectedStatsBuilder.withMin(expected.getMin_value()); + } + if (expected.isSetMax_value()) { + expectedStatsBuilder.withMax(expected.getMax_value()); + } + return actual.equals(expectedStatsBuilder.build()); + } + + private static void validateColumnDescriptorsSame(ColumnDescriptor actual, ColumnDescriptor expected, ParquetDataSourceId dataSourceId) + throws ParquetCorruptionException + { + // Column names are lower-cased by MetadataReader#readFooter + validateParquet( + Arrays.equals(actual.getPath(), Arrays.stream(expected.getPath()).map(field -> field.toLowerCase(Locale.ENGLISH)).toArray()), + dataSourceId, + "Column path %s did not match expected column path %s", + actual.getPath(), + expected.getPath()); + + validateParquet( + actual.getMaxDefinitionLevel() == expected.getMaxDefinitionLevel(), + dataSourceId, + "Column %s max definition level %d did not match expected max definition level %d", + actual.getPath(), + actual.getMaxDefinitionLevel(), + expected.getMaxDefinitionLevel()); + + validateParquet( + actual.getMaxRepetitionLevel() == expected.getMaxRepetitionLevel(), + dataSourceId, + "Column %s max repetition level %d did not match expected max repetition level %d", + actual.getPath(), + actual.getMaxRepetitionLevel(), + expected.getMaxRepetitionLevel()); + + PrimitiveType actualPrimitiveType = actual.getPrimitiveType(); + PrimitiveType expectedPrimitiveType = expected.getPrimitiveType(); + // We don't use PrimitiveType#equals directly because column names are lower-cased by MetadataReader#readFooter + validateParquet( + actualPrimitiveType.getPrimitiveTypeName().equals(expectedPrimitiveType.getPrimitiveTypeName()) + && actualPrimitiveType.getTypeLength() == expectedPrimitiveType.getTypeLength() + && actualPrimitiveType.getRepetition().equals(expectedPrimitiveType.getRepetition()) + && actualPrimitiveType.getName().equals(expectedPrimitiveType.getName().toLowerCase(Locale.ENGLISH)) + && Objects.equals(actualPrimitiveType.getLogicalTypeAnnotation(), expectedPrimitiveType.getLogicalTypeAnnotation()), + dataSourceId, + "Column %s primitive type %s did not 
match expected primitive type %s", + actual.getPath(), + actualPrimitiveType, + expectedPrimitiveType); + } + + private static long estimatedSizeOfStringArray(String[] path) + { + long size = sizeOf(path); + for (String field : path) { + size += estimatedSizeOf(field); + } + return size; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ValidationHash.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ValidationHash.java new file mode 100644 index 000000000000..86cb0ceceafa --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ValidationHash.java @@ -0,0 +1,144 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet; + +import io.trino.spi.block.Block; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodType; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.type.StandardTypes.ARRAY; +import static io.trino.spi.type.StandardTypes.MAP; +import static io.trino.spi.type.StandardTypes.ROW; +import static java.lang.invoke.MethodHandles.lookup; +import static java.util.Objects.requireNonNull; + +/** + * Based on io.trino.rcfile.ValidationHash and io.trino.orc.ValidationHash + * with minor differences in handling of timestamp and map types. 
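+ * Nulls hash to a fixed constant and nested values are hashed element by element, so the
+ * per-column hashes computed while reading the file back can be compared with those
+ * recorded during the write.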
+ */ +class ValidationHash +{ + // This value is a large arbitrary prime + private static final long NULL_HASH_CODE = 0x6e3efbd56c16a0cbL; + + private static final MethodHandle MAP_HASH; + private static final MethodHandle ARRAY_HASH; + private static final MethodHandle ROW_HASH; + + static { + try { + MAP_HASH = lookup().findStatic( + ValidationHash.class, + "mapHash", + MethodType.methodType(long.class, Type.class, ValidationHash.class, ValidationHash.class, Block.class, int.class)); + ARRAY_HASH = lookup().findStatic( + ValidationHash.class, + "arrayHash", + MethodType.methodType(long.class, Type.class, ValidationHash.class, Block.class, int.class)); + ROW_HASH = lookup().findStatic( + ValidationHash.class, + "rowHash", + MethodType.methodType(long.class, Type.class, ValidationHash[].class, Block.class, int.class)); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + // This should really come from the environment, but there is not good way to get a value here + private static final TypeOperators VALIDATION_TYPE_OPERATORS_CACHE = new TypeOperators(); + + public static ValidationHash createValidationHash(Type type) + { + requireNonNull(type, "type is null"); + if (type.getTypeSignature().getBase().equals(MAP)) { + ValidationHash keyHash = createValidationHash(type.getTypeParameters().get(0)); + ValidationHash valueHash = createValidationHash(type.getTypeParameters().get(1)); + return new ValidationHash(MAP_HASH.bindTo(type).bindTo(keyHash).bindTo(valueHash)); + } + + if (type.getTypeSignature().getBase().equals(ARRAY)) { + ValidationHash elementHash = createValidationHash(type.getTypeParameters().get(0)); + return new ValidationHash(ARRAY_HASH.bindTo(type).bindTo(elementHash)); + } + + if (type.getTypeSignature().getBase().equals(ROW)) { + ValidationHash[] fieldHashes = type.getTypeParameters().stream() + .map(ValidationHash::createValidationHash) + .toArray(ValidationHash[]::new); + return new ValidationHash(ROW_HASH.bindTo(type).bindTo(fieldHashes)); + } + + return new ValidationHash(VALIDATION_TYPE_OPERATORS_CACHE.getHashCodeOperator(type, InvocationConvention.simpleConvention(FAIL_ON_NULL, BLOCK_POSITION))); + } + + private final MethodHandle hashCodeOperator; + + private ValidationHash(MethodHandle hashCodeOperator) + { + this.hashCodeOperator = requireNonNull(hashCodeOperator, "hashCodeOperator is null"); + } + + public long hash(Block block, int position) + { + if (block.isNull(position)) { + return NULL_HASH_CODE; + } + try { + return (long) hashCodeOperator.invokeExact(block, position); + } + catch (Throwable throwable) { + throwIfUnchecked(throwable); + throw new RuntimeException(throwable); + } + } + + private static long mapHash(Type type, ValidationHash keyHash, ValidationHash valueHash, Block block, int position) + { + Block mapBlock = (Block) type.getObject(block, position); + long hash = 0; + for (int i = 0; i < mapBlock.getPositionCount(); i += 2) { + hash = 31 * hash + keyHash.hash(mapBlock, i); + hash = 31 * hash + valueHash.hash(mapBlock, i + 1); + } + return hash; + } + + private static long arrayHash(Type type, ValidationHash elementHash, Block block, int position) + { + Block array = (Block) type.getObject(block, position); + long hash = 0; + for (int i = 0; i < array.getPositionCount(); i++) { + hash = 31 * hash + elementHash.hash(array, i); + } + return hash; + } + + private static long rowHash(Type type, ValidationHash[] fieldHashes, Block block, int position) + { + Block row = (Block) type.getObject(block, position); + long hash = 0; + for 
(int i = 0; i < row.getPositionCount(); i++) { + hash = 31 * hash + fieldHashes[i].hash(row, i); + } + return hash; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java index ad267e7c8375..4dc824c16c12 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java @@ -16,7 +16,10 @@ import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetWriteValidation; import org.apache.parquet.CorruptStatistics; import org.apache.parquet.column.statistics.BinaryStatistics; import org.apache.parquet.format.ColumnChunk; @@ -61,6 +64,7 @@ import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static org.apache.hadoop.hive.ql.io.parquet.write.DataWritableWriteSupport.WRITER_TIMEZONE; import static org.apache.parquet.format.Util.readFileMetaData; import static org.apache.parquet.format.converter.ParquetMetadataConverterUtil.getLogicalTypeAnnotation; @@ -75,7 +79,7 @@ public final class MetadataReader private MetadataReader() {} - public static ParquetMetadata readFooter(ParquetDataSource dataSource) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation) throws IOException { // Parquet File Layout: @@ -165,7 +169,12 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource) keyValueMetaData.put(keyValue.key, keyValue.value); } } - return new ParquetMetadata(new org.apache.parquet.hadoop.metadata.FileMetaData(messageType, keyValueMetaData, fileMetaData.getCreated_by()), blocks); + org.apache.parquet.hadoop.metadata.FileMetaData parquetFileMetadata = new org.apache.parquet.hadoop.metadata.FileMetaData( + messageType, + keyValueMetaData, + fileMetaData.getCreated_by()); + validateFileMetadata(dataSource.getId(), parquetFileMetadata, parquetWriteValidation); + return new ParquetMetadata(parquetFileMetadata, blocks); } private static MessageType readParquetSchema(List schema) @@ -380,4 +389,17 @@ private static IndexReference toOffsetIndexReference(ColumnChunk columnChunk) } return null; } + + private static void validateFileMetadata(ParquetDataSourceId dataSourceId, org.apache.parquet.hadoop.metadata.FileMetaData fileMetaData, Optional parquetWriteValidation) + throws ParquetCorruptionException + { + if (parquetWriteValidation.isEmpty()) { + return; + } + ParquetWriteValidation writeValidation = parquetWriteValidation.get(); + writeValidation.validateTimeZone( + dataSourceId, + Optional.ofNullable(fileMetaData.getKeyValueMetaData().get(WRITER_TIMEZONE))); + writeValidation.validateColumns(dataSourceId, fileMetaData.getSchema()); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetBlockFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetBlockFactory.java new file mode 100644 index 000000000000..241c3db01bd6 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetBlockFactory.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import io.trino.spi.block.Block; +import io.trino.spi.block.LazyBlock; +import io.trino.spi.block.LazyBlockLoader; + +import java.io.IOException; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class ParquetBlockFactory +{ + private final Function exceptionTransform; + private int currentPageId; + + public ParquetBlockFactory(Function exceptionTransform) + { + this.exceptionTransform = requireNonNull(exceptionTransform, "exceptionTransform is null"); + } + + public void nextPage() + { + currentPageId++; + } + + public Block createBlock(int positionCount, ParquetBlockReader reader) + { + return new LazyBlock(positionCount, new ParquetBlockLoader(reader)); + } + + public interface ParquetBlockReader + { + Block readBlock() + throws IOException; + } + + private final class ParquetBlockLoader + implements LazyBlockLoader + { + private final int expectedPageId = currentPageId; + private final ParquetBlockReader blockReader; + private boolean loaded; + + public ParquetBlockLoader(ParquetBlockReader blockReader) + { + this.blockReader = requireNonNull(blockReader, "blockReader is null"); + } + + @Override + public Block load() + { + checkState(!loaded, "Already loaded"); + checkState(currentPageId == expectedPageId, "Parquet reader has been advanced beyond block"); + + loaded = true; + try { + return blockReader.readBlock(); + } + catch (IOException | RuntimeException e) { + throw exceptionTransform.apply(e); + } + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index 582fd4ca8d6e..d653937a4247 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -29,10 +29,12 @@ import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.PrimitiveField; import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.FilteredOffsetIndex.OffsetRange; import io.trino.plugin.base.metrics.LongCount; +import io.trino.spi.Page; import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.RowBlock; @@ -67,11 +69,16 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static io.trino.parquet.ParquetValidationUtils.validateParquet; +import static io.trino.parquet.ParquetWriteValidation.StatisticsValidation; +import static io.trino.parquet.ParquetWriteValidation.StatisticsValidation.createStatisticsValidationBuilder; +import static io.trino.parquet.ParquetWriteValidation.WriteChecksumBuilder; +import static 
io.trino.parquet.ParquetWriteValidation.WriteChecksumBuilder.createWriteChecksumBuilder; import static io.trino.parquet.reader.ListColumnReader.calculateCollectionOffsets; import static java.lang.Math.max; import static java.lang.Math.min; @@ -90,7 +97,8 @@ public class ParquetReader private final Optional fileCreatedBy; private final List blocks; private final List firstRowsOfBlocks; - private final List fields; + private final List columnFields; + private final List primitiveFields; private final ParquetDataSource dataSource; private final DateTimeZone timeZone; private final AggregatedMemoryContext memoryContext; @@ -119,41 +127,48 @@ public class ParquetReader private AggregatedMemoryContext currentRowGroupMemoryContext; private final Multimap chunkReaders; private final List> columnIndexStore; + private final Optional writeValidation; + private final Optional writeChecksumBuilder; + private final Optional rowGroupStatisticsValidation; private final List blockRowRanges; private final Map paths = new HashMap<>(); + private final ParquetBlockFactory blockFactory; private final Map> codecMetrics; public ParquetReader( Optional fileCreatedBy, - List> fields, + List columnFields, List blocks, List firstRowsOfBlocks, ParquetDataSource dataSource, DateTimeZone timeZone, AggregatedMemoryContext memoryContext, ParquetReaderOptions options, - Optional parquetPredicate) + Function exceptionTransform) throws IOException { - this(fileCreatedBy, fields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, parquetPredicate, nCopies(blocks.size(), Optional.empty())); + this(fileCreatedBy, columnFields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, exceptionTransform, Optional.empty(), nCopies(blocks.size(), Optional.empty()), Optional.empty()); } public ParquetReader( Optional fileCreatedBy, - List> fields, + List columnFields, List blocks, List firstRowsOfBlocks, ParquetDataSource dataSource, DateTimeZone timeZone, AggregatedMemoryContext memoryContext, ParquetReaderOptions options, + Function exceptionTransform, Optional parquetPredicate, - List> columnIndexStore) + List> columnIndexStore, + Optional writeValidation) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); - List primitiveFields = getPrimitiveFields(requireNonNull(fields, "fields is null")); - this.fields = primitiveFields; + requireNonNull(columnFields, "columnFields is null"); + this.columnFields = ImmutableList.copyOf(columnFields); + this.primitiveFields = getPrimitiveFields(columnFields); this.blocks = requireNonNull(blocks, "blocks is null"); this.firstRowsOfBlocks = requireNonNull(firstRowsOfBlocks, "firstRowsOfBlocks is null"); this.dataSource = requireNonNull(dataSource, "dataSource is null"); @@ -166,6 +181,16 @@ public ParquetReader( checkArgument(blocks.size() == firstRowsOfBlocks.size(), "elements of firstRowsOfBlocks must correspond to blocks"); + this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); + validateWrite( + validation -> fileCreatedBy.equals(Optional.of(validation.getCreatedBy())), + "Expected created by %s, found %s", + writeValidation.map(ParquetWriteValidation::getCreatedBy), + fileCreatedBy); + validateBlockMetadata(blocks); + this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(validation.getTypes())); + this.rowGroupStatisticsValidation = writeValidation.map(validation -> createStatisticsValidationBuilder(validation.getTypes())); + this.blockRowRanges 
= listWithNulls(this.blocks.size()); for (PrimitiveField field : primitiveFields) { ColumnDescriptor columnDescriptor = field.getDescriptor(); @@ -180,6 +205,7 @@ public ParquetReader( else { this.filter = Optional.empty(); } + this.blockFactory = new ParquetBlockFactory(exceptionTransform); ListMultimap ranges = ArrayListMultimap.create(); Map codecMetrics = new HashMap<>(); for (int rowGroup = 0; rowGroup < blocks.size(); rowGroup++) { @@ -224,6 +250,30 @@ public void close() freeCurrentRowGroupBuffers(); currentRowGroupMemoryContext.close(); dataSource.close(); + + if (writeChecksumBuilder.isPresent()) { + ParquetWriteValidation parquetWriteValidation = writeValidation.orElseThrow(); + parquetWriteValidation.validateChecksum(dataSource.getId(), writeChecksumBuilder.get().build()); + } + } + + public Page nextPage() + throws IOException + { + int batchSize = nextBatch(); + if (batchSize <= 0) { + return null; + } + // create a lazy page + blockFactory.nextPage(); + Block[] blocks = new Block[columnFields.size()]; + for (int channel = 0; channel < columnFields.size(); channel++) { + Field field = columnFields.get(channel); + blocks[channel] = blockFactory.createBlock(batchSize, () -> readBlock(field)); + } + Page page = new Page(batchSize, blocks); + validateWritePageChecksum(page); + return page; } /** @@ -234,7 +284,8 @@ public long lastBatchStartRow() return firstRowIndexInGroup + nextRowInGroup - batchSize; } - public int nextBatch() + private int nextBatch() + throws IOException { if (nextRowInGroup >= currentGroupRowCount && !advanceToNextRowGroup()) { return -1; @@ -250,10 +301,18 @@ public int nextBatch() } private boolean advanceToNextRowGroup() + throws IOException { currentRowGroupMemoryContext.close(); currentRowGroupMemoryContext = memoryContext.newAggregatedMemoryContext(); freeCurrentRowGroupBuffers(); + + if (currentRowGroup >= 0 && rowGroupStatisticsValidation.isPresent()) { + StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.get(); + writeValidation.orElseThrow().validateRowGroupStatistics(dataSource.getId(), currentBlockMetadata, statisticsValidation.build()); + statisticsValidation.reset(); + } + currentRowGroup++; if (currentRowGroup == blocks.size()) { return false; @@ -282,7 +341,7 @@ private void freeCurrentRowGroupBuffers() return; } - for (int column = 0; column < fields.size(); column++) { + for (int column = 0; column < primitiveFields.size(); column++) { Collection readers = chunkReaders.get(new ChunkKey(column, currentRowGroup)); if (readers != null) { for (ChunkReader reader : readers) { @@ -433,18 +492,15 @@ private ColumnChunkMetaData getColumnChunkMetaData(BlockMetaData blockMetaData, private void initializeColumnReaders() { - for (PrimitiveField field : fields) { + for (PrimitiveField field : primitiveFields) { columnReaders.put(field.getId(), PrimitiveColumnReader.createReader(field, timeZone)); } } - public static List getPrimitiveFields(List> fields) + public static List getPrimitiveFields(List fields) { Map primitiveFields = new HashMap<>(); - - fields.stream() - .flatMap(Optional::stream) - .forEach(field -> parseField(field, primitiveFields)); + fields.forEach(field -> parseField(field, primitiveFields)); return ImmutableList.copyOf(primitiveFields.values()); } @@ -521,4 +577,29 @@ private RowRanges getRowRanges(FilterPredicate filter, int blockIndex) } return rowRanges; } + + private void validateWritePageChecksum(Page page) + { + if (writeChecksumBuilder.isPresent()) { + page = page.getLoadedPage(); + 
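+ // the page is materialized above so both the checksum and the row group
+ // statistics are computed over actual values rather than lazy blocks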
writeChecksumBuilder.get().addPage(page); + rowGroupStatisticsValidation.orElseThrow().addPage(page); + } + } + + private void validateBlockMetadata(List blockMetaData) + throws ParquetCorruptionException + { + if (writeValidation.isPresent()) { + writeValidation.get().validateBlocksMetadata(dataSource.getId(), blockMetaData); + } + } + + private void validateWrite(java.util.function.Predicate test, String messageFormat, Object... args) + throws ParquetCorruptionException + { + if (writeValidation.isPresent() && !test.test(writeValidation.get())) { + throw new ParquetCorruptionException(dataSource.getId(), "Write validation failed: " + messageFormat, args); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReaderColumn.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReaderColumn.java new file mode 100644 index 000000000000..a2060e273711 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReaderColumn.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import io.trino.parquet.Field; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * @param type Column type + * @param field Field description. 
Empty optional will result in column populated with {@code NULL} + * @param isRowIndexColumn Whether column should be populated with the indices of its rows + */ +public record ParquetReaderColumn(Type type, Optional field, boolean isRowIndexColumn) +{ + public static List getParquetReaderFields(List parquetReaderColumns) + { + return parquetReaderColumns.stream() + .filter(column -> !column.isRowIndexColumn()) + .map(ParquetReaderColumn::field) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableList()); + } + + public ParquetReaderColumn(Type type, Optional field, boolean isRowIndexColumn) + { + this.type = requireNonNull(type, "type is null"); + this.field = requireNonNull(field, "field is null"); + checkArgument( + !isRowIndexColumn || field.isEmpty(), + "Field info for row index column must be empty Optional"); + this.isRowIndexColumn = isRowIndexColumn; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 950f1fa4c8a8..934191dab0a2 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -20,6 +20,13 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.units.DataSize; +import io.trino.parquet.Field; +import io.trino.parquet.ParquetCorruptionException; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.reader.MetadataReader; +import io.trino.parquet.reader.ParquetReader; import io.trino.parquet.writer.ColumnWriter.BufferData; import io.trino.spi.Page; import io.trino.spi.type.Type; @@ -29,7 +36,10 @@ import org.apache.parquet.format.KeyValue; import org.apache.parquet.format.RowGroup; import org.apache.parquet.format.Util; +import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.schema.MessageType; import org.joda.time.DateTimeZone; import org.openjdk.jol.info.ClassLayout; @@ -40,19 +50,27 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.Slices.wrappedBuffer; import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.trino.parquet.ParquetTypeUtils.constructField; +import static io.trino.parquet.ParquetTypeUtils.getColumnIO; +import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; +import static io.trino.parquet.ParquetWriteValidation.ParquetWriteValidationBuilder; import static io.trino.parquet.writer.ParquetDataOutput.createDataOutput; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; import static java.nio.charset.StandardCharsets.US_ASCII; +import static 
java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; import static org.apache.hadoop.hive.ql.io.parquet.write.DataWritableWriteSupport.WRITER_TIMEZONE; import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0; @@ -74,6 +92,7 @@ public class ParquetWriter private final Optional parquetTimeZone; private final ImmutableList.Builder rowGroupBuilder = ImmutableList.builder(); + private final Optional validationBuilder; private List columnWriters; private int rows; @@ -90,17 +109,23 @@ public ParquetWriter( ParquetWriterOptions writerOption, CompressionCodecName compressionCodecName, String trinoVersion, - Optional parquetTimeZone) + Optional parquetTimeZone, + Optional validationBuilder) { + this.validationBuilder = requireNonNull(validationBuilder, "validationBuilder is null"); this.outputStream = new OutputStreamSliceOutput(requireNonNull(outputStream, "outputstream is null")); this.messageType = requireNonNull(messageType, "messageType is null"); this.primitiveTypes = requireNonNull(primitiveTypes, "primitiveTypes is null"); this.writerOption = requireNonNull(writerOption, "writerOption is null"); this.compressionCodecName = requireNonNull(compressionCodecName, "compressionCodecName is null"); this.parquetTimeZone = requireNonNull(parquetTimeZone, "parquetTimeZone is null"); + this.createdBy = formatCreatedBy(requireNonNull(trinoVersion, "trinoVersion is null")); + + recordValidation(validation -> validation.setTimeZone(parquetTimeZone.map(DateTimeZone::getID))); + recordValidation(validation -> validation.setColumns(messageType.getColumns())); + recordValidation(validation -> validation.setCreatedBy(createdBy)); initColumnWriters(); this.chunkMaxLogicalBytes = max(1, CHUNK_MAX_BYTES / 2); - this.createdBy = formatCreatedBy(requireNonNull(trinoVersion, "trinoVersion is null")); } public long getWrittenBytes() @@ -117,7 +142,8 @@ public long getRetainedBytes() { return INSTANCE_SIZE + outputStream.getRetainedSize() + - columnWriters.stream().mapToLong(ColumnWriter::getRetainedBytes).sum(); + columnWriters.stream().mapToLong(ColumnWriter::getRetainedBytes).sum() + + validationBuilder.map(ParquetWriteValidationBuilder::getRetainedSize).orElse(0L); } public void write(Page page) @@ -131,6 +157,9 @@ public void write(Page page) checkArgument(page.getChannelCount() == columnWriters.size()); + Page validationPage = page; + recordValidation(validation -> validation.addPage(validationPage)); + while (page != null) { int chunkRows = min(page.getPositionCount(), writerOption.getBatchSize()); Page chunk = page.getRegion(0, chunkRows); @@ -190,6 +219,70 @@ public void close() bufferedBytes = 0; } + public void validate(ParquetDataSource input) + throws ParquetCorruptionException + { + checkState(validationBuilder.isPresent(), "validation is not enabled"); + ParquetWriteValidation writeValidation = validationBuilder.get().build(); + try { + ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation)); + try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) { + for (Page page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) { + // fully load the page + page.getLoadedPage(); + } + } + } + catch (IOException e) { + if (e instanceof ParquetCorruptionException) { + throw (ParquetCorruptionException) e; + } + throw new ParquetCorruptionException(input.getId(), "Validation failed with exception %s", e); + } + } + + private ParquetReader 
createParquetReader(ParquetDataSource input, ParquetMetadata parquetMetadata, ParquetWriteValidation writeValidation) + throws IOException + { + org.apache.parquet.hadoop.metadata.FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); + MessageColumnIO messageColumnIO = getColumnIO(fileMetaData.getSchema(), fileMetaData.getSchema()); + ImmutableList.Builder columnFields = ImmutableList.builder(); + for (int i = 0; i < writeValidation.getTypes().size(); i++) { + columnFields.add(constructField( + writeValidation.getTypes().get(i), + lookupColumnByName(messageColumnIO, writeValidation.getColumnNames().get(i))) + .orElseThrow()); + } + long nextStart = 0; + ImmutableList.Builder blockStartsBuilder = ImmutableList.builder(); + for (BlockMetaData block : parquetMetadata.getBlocks()) { + blockStartsBuilder.add(nextStart); + nextStart += block.getRowCount(); + } + List blockStarts = blockStartsBuilder.build(); + return new ParquetReader( + Optional.ofNullable(fileMetaData.getCreatedBy()), + columnFields.build(), + parquetMetadata.getBlocks(), + blockStarts, + input, + parquetTimeZone.orElseThrow(), + newSimpleAggregatedMemoryContext(), + new ParquetReaderOptions(), + exception -> { + throwIfUnchecked(exception); + return new RuntimeException(exception); + }, + Optional.empty(), + nCopies(blockStarts.size(), Optional.empty()), + Optional.of(writeValidation)); + } + + private void recordValidation(Consumer task) + { + validationBuilder.ifPresent(task); + } + // Parquet File Layout: // // MAGIC @@ -239,7 +332,9 @@ private void writeFooter() throws IOException { checkState(closed); - Slice footer = getFooter(rowGroupBuilder.build(), messageType); + List rowGroups = rowGroupBuilder.build(); + Slice footer = getFooter(rowGroups, messageType); + recordValidation(validation -> validation.setRowGroups(rowGroups)); createDataOutput(footer).writeData(outputStream); Slice footerSize = Slices.allocate(SIZE_OF_INT); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index f9a0ce2a5f2d..cc2ce8f27f68 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -225,11 +225,12 @@ private FileWriter createParquetFileWriter(FileSystem fileSystem, Path path, Lis }) .collect(toImmutableList()); + List dataColumnNames = dataColumns.stream() + .map(DeltaLakeColumnHandle::getName) + .collect(toImmutableList()); ParquetSchemaConverter schemaConverter = new ParquetSchemaConverter( parquetTypes, - dataColumns.stream() - .map(DeltaLakeColumnHandle::getName) - .collect(toImmutableList()), + dataColumnNames, false, false); @@ -237,12 +238,14 @@ private FileWriter createParquetFileWriter(FileSystem fileSystem, Path path, Lis fileSystem.create(path), rollbackAction, parquetTypes, + dataColumnNames, schemaConverter.getMessageType(), schemaConverter.getPrimitiveTypes(), parquetWriterOptions, IntStream.range(0, dataColumns.size()).toArray(), compressionCodecName, trinoVersion, + Optional.empty(), Optional.empty()); } catch (IOException e) { @@ -310,7 +313,8 @@ private ReaderPageSource createParquetPageSource(Path path) true, parquetDateTimeZone, new FileFormatDataSourceStats(), - new ParquetReaderOptions()); + new ParquetReaderOptions(), + Optional.empty()); } private static class FileDeletion diff --git 
a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java index 85dc41403550..b4d0cf0218c1 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java @@ -496,12 +496,14 @@ private FileWriter createParquetFileWriter(Path path) fileSystem.create(path), rollbackAction, parquetTypes, + dataColumnNames, schemaConverter.getMessageType(), schemaConverter.getPrimitiveTypes(), parquetWriterOptions, identityMapping, compressionCodecName, trinoVersion, + Optional.empty(), Optional.empty()); } catch (IOException e) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index b3d7b321eba7..50039d5859dc 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -180,7 +180,8 @@ public ConnectorPageSource createPageSource( parquetDateTimeZone, fileFormatDataSourceStats, parquetReaderOptions.withMaxReadBlockSize(getParquetMaxReadBlockSize(session)) - .withUseColumnIndex(isParquetUseColumnIndex(session))); + .withUseColumnIndex(isParquetUseColumnIndex(session)), + Optional.empty()); verify(pageSource.getReaderColumns().isEmpty(), "All columns expected to be base columns"); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java index 46b33bb5d785..6aee909915af 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java @@ -578,7 +578,8 @@ private ReaderPageSource createParquetPageSource(TupleDomain p parquetDateTimeZone, new FileFormatDataSourceStats(), parquetReaderOptions.withMaxReadBlockSize(getParquetMaxReadBlockSize(this.session)) - .withUseColumnIndex(isParquetUseColumnIndex(this.session))); + .withUseColumnIndex(isParquetUseColumnIndex(this.session)), + Optional.empty()); } private DeltaLakeWriter createWriter(Path targetFile, List allColumns, List dataColumns) diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index 41f7c8f856c5..8a52ffc5efa2 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -181,7 +181,8 @@ public CheckpointEntryIterator( true, DateTimeZone.UTC, stats, - parquetReaderOptions); + parquetReaderOptions, + Optional.empty()); verify(pageSource.getReaderColumns().isEmpty(), "All columns expected to be base columns"); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java index 22052f828b87..98019b128764 
100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java @@ -89,6 +89,7 @@ public final class HiveSessionProperties private static final String PARQUET_WRITER_BLOCK_SIZE = "parquet_writer_block_size"; private static final String PARQUET_WRITER_PAGE_SIZE = "parquet_writer_page_size"; private static final String PARQUET_WRITER_BATCH_SIZE = "parquet_writer_batch_size"; + private static final String PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE = "parquet_optimized_writer_validation_percentage"; private static final String MAX_SPLIT_SIZE = "max_split_size"; private static final String MAX_INITIAL_SPLIT_SIZE = "max_initial_split_size"; private static final String RCFILE_OPTIMIZED_WRITER_VALIDATE = "rcfile_optimized_writer_validate"; @@ -338,6 +339,23 @@ public HiveSessionProperties( "Parquet: Maximum number of rows passed to the writer in each batch", parquetWriterConfig.getBatchSize(), false), + new PropertyMetadata<>( + PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE, + "Parquet: sample percentage for validation of written files", + DOUBLE, + Double.class, + parquetWriterConfig.getValidationPercentage(), + false, + value -> { + double doubleValue = (double) value; + if (doubleValue < 0.0 || doubleValue > 100.0) { + throw new TrinoException( + INVALID_SESSION_PROPERTY, + format("%s must be between 0.0 and 100.0 inclusive: %s", PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE, doubleValue)); + } + return doubleValue; + }, + value -> value), dataSizeProperty( MAX_SPLIT_SIZE, "Max split size", @@ -685,6 +703,13 @@ public static int getParquetBatchSize(ConnectorSession session) return session.getProperty(PARQUET_WRITER_BATCH_SIZE, Integer.class); } + public static boolean isParquetOptimizedWriterValidate(ConnectorSession session) + { + double percentage = session.getProperty(PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE, Double.class); + checkArgument(percentage >= 0.0 && percentage <= 100.0); + return ThreadLocalRandom.current().nextDouble(100) < percentage; + } + public static DataSize getMaxSplitSize(ConnectorSession session) { return session.getProperty(MAX_SPLIT_SIZE, DataSize.class); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/HiveParquetColumnIOConverter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/HiveParquetColumnIOConverter.java deleted file mode 100644 index 7487639861a7..000000000000 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/HiveParquetColumnIOConverter.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.plugin.hive.parquet; - -import com.google.common.collect.ImmutableList; -import io.trino.parquet.Field; -import io.trino.parquet.GroupField; -import io.trino.parquet.PrimitiveField; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.MapType; -import io.trino.spi.type.RowType; -import io.trino.spi.type.Type; -import org.apache.parquet.io.ColumnIO; -import org.apache.parquet.io.GroupColumnIO; -import org.apache.parquet.io.PrimitiveColumnIO; - -import java.util.List; -import java.util.Locale; -import java.util.Optional; - -import static io.trino.parquet.ParquetTypeUtils.getArrayElementColumn; -import static io.trino.parquet.ParquetTypeUtils.getMapKeyValueColumn; -import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; -import static org.apache.parquet.io.ColumnIOUtil.columnDefinitionLevel; -import static org.apache.parquet.io.ColumnIOUtil.columnRepetitionLevel; -import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; - -public final class HiveParquetColumnIOConverter -{ - private HiveParquetColumnIOConverter() {} - - public static Optional constructField(Type type, ColumnIO columnIO) - { - if (columnIO == null) { - return Optional.empty(); - } - boolean required = columnIO.getType().getRepetition() != OPTIONAL; - int repetitionLevel = columnRepetitionLevel(columnIO); - int definitionLevel = columnDefinitionLevel(columnIO); - if (type instanceof RowType) { - RowType rowType = (RowType) type; - GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; - ImmutableList.Builder> fieldsBuilder = ImmutableList.builder(); - List fields = rowType.getFields(); - boolean structHasParameters = false; - for (int i = 0; i < fields.size(); i++) { - RowType.Field rowField = fields.get(i); - String name = rowField.getName().orElseThrow().toLowerCase(Locale.ENGLISH); - Optional field = constructField(rowField.getType(), lookupColumnByName(groupColumnIO, name)); - structHasParameters |= field.isPresent(); - fieldsBuilder.add(field); - } - if (structHasParameters) { - return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, fieldsBuilder.build())); - } - return Optional.empty(); - } - if (type instanceof MapType) { - MapType mapType = (MapType) type; - GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; - GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO); - if (keyValueColumnIO.getChildrenCount() != 2) { - return Optional.empty(); - } - Optional keyField = constructField(mapType.getKeyType(), keyValueColumnIO.getChild(0)); - Optional valueField = constructField(mapType.getValueType(), keyValueColumnIO.getChild(1)); - return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(keyField, valueField))); - } - if (type instanceof ArrayType) { - ArrayType arrayType = (ArrayType) type; - GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; - if (groupColumnIO.getChildrenCount() != 1) { - return Optional.empty(); - } - Optional field = constructField(arrayType.getElementType(), getArrayElementColumn(groupColumnIO.getChild(0))); - return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field))); - } - PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO; - return Optional.of(new PrimitiveField(type, required, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId())); - } -} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java 
b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java index fe31b0ec367a..67ab1d3f4d36 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java @@ -14,6 +14,7 @@ package io.trino.plugin.hive.parquet; import com.google.common.collect.ImmutableList; +import io.trino.parquet.ParquetDataSource; import io.trino.parquet.writer.ParquetWriter; import io.trino.parquet.writer.ParquetWriterOptions; import io.trino.plugin.hive.FileWriter; @@ -31,40 +32,51 @@ import java.io.IOException; import java.io.OutputStream; import java.io.UncheckedIOException; +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.Callable; +import java.util.function.Supplier; import static com.google.common.base.MoreObjects.toStringHelper; +import static io.trino.parquet.ParquetWriteValidation.ParquetWriteValidationBuilder; import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR; import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_DATA_ERROR; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITE_VALIDATION_FAILED; import static java.util.Objects.requireNonNull; public class ParquetFileWriter implements FileWriter { private static final int INSTANCE_SIZE = ClassLayout.parseClass(ParquetFileWriter.class).instanceSize(); + private static final ThreadMXBean THREAD_MX_BEAN = ManagementFactory.getThreadMXBean(); private final ParquetWriter parquetWriter; private final Callable rollbackAction; private final int[] fileInputColumnIndexes; private final List nullBlocks; + private final Optional> validationInputFactory; + private long validationCpuNanos; public ParquetFileWriter( OutputStream outputStream, Callable rollbackAction, List fileColumnTypes, + List fileColumnNames, MessageType messageType, Map, Type> primitiveTypes, ParquetWriterOptions parquetWriterOptions, int[] fileInputColumnIndexes, CompressionCodecName compressionCodecName, String trinoVersion, - Optional parquetTimeZone) + Optional parquetTimeZone, + Optional> validationInputFactory) { requireNonNull(outputStream, "outputStream is null"); requireNonNull(trinoVersion, "trinoVersion is null"); + this.validationInputFactory = requireNonNull(validationInputFactory, "validationInputFactory is null"); this.parquetWriter = new ParquetWriter( outputStream, @@ -73,7 +85,10 @@ public ParquetFileWriter( parquetWriterOptions, compressionCodecName, trinoVersion, - parquetTimeZone); + parquetTimeZone, + validationInputFactory.isPresent() + ? 
Optional.of(new ParquetWriteValidationBuilder(fileColumnTypes, fileColumnNames)) + : Optional.empty()); this.rollbackAction = requireNonNull(rollbackAction, "rollbackAction is null"); this.fileInputColumnIndexes = requireNonNull(fileInputColumnIndexes, "fileInputColumnIndexes is null"); @@ -136,6 +151,19 @@ public void commit() } throw new TrinoException(HIVE_WRITER_CLOSE_ERROR, "Error committing write parquet to Hive", e); } + + if (validationInputFactory.isPresent()) { + try { + try (ParquetDataSource input = validationInputFactory.get().get()) { + long startThreadCpuTime = THREAD_MX_BEAN.getCurrentThreadCpuTime(); + parquetWriter.validate(input); + validationCpuNanos += THREAD_MX_BEAN.getCurrentThreadCpuTime() - startThreadCpuTime; + } + } + catch (IOException | UncheckedIOException e) { + throw new TrinoException(HIVE_WRITE_VALIDATION_FAILED, e); + } + } } @Override @@ -157,7 +185,7 @@ public void rollback() @Override public long getValidationCpuNanos() { - return 0; + return validationCpuNanos; } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java index c8719ab71eb9..5f6d54eaf282 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java @@ -14,8 +14,12 @@ package io.trino.plugin.hive.parquet; import io.trino.hdfs.HdfsEnvironment; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.writer.ParquetSchemaConverter; import io.trino.parquet.writer.ParquetWriterOptions; +import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.FileWriter; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HiveFileWriterFactory; @@ -35,6 +39,8 @@ import org.apache.parquet.hadoop.ParquetOutputFormat; import org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.joda.time.DateTimeZone; +import org.weakref.jmx.Flatten; +import org.weakref.jmx.Managed; import javax.inject.Inject; @@ -44,11 +50,14 @@ import java.util.OptionalInt; import java.util.Properties; import java.util.concurrent.Callable; +import java.util.function.Supplier; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_INT96_TIMESTAMP_ENCODING; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_LEGACY_DECIMAL_ENCODING; import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_OPEN_ERROR; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITE_VALIDATION_FAILED; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; +import static io.trino.plugin.hive.HiveSessionProperties.isParquetOptimizedWriterValidate; import static io.trino.plugin.hive.util.HiveUtil.getColumnNames; import static io.trino.plugin.hive.util.HiveUtil.getColumnTypes; import static java.util.Objects.requireNonNull; @@ -61,18 +70,21 @@ public class ParquetFileWriterFactory private final NodeVersion nodeVersion; private final TypeManager typeManager; private final DateTimeZone parquetTimeZone; + private final FileFormatDataSourceStats readStats; @Inject public ParquetFileWriterFactory( HdfsEnvironment hdfsEnvironment, NodeVersion nodeVersion, TypeManager typeManager, - HiveConfig hiveConfig) + HiveConfig hiveConfig, + FileFormatDataSourceStats 
readStats) { this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.parquetTimeZone = hiveConfig.getParquetDateTimeZone(); + this.readStats = requireNonNull(readStats, "readStats is null"); } @Override @@ -127,23 +139,49 @@ public Optional createFileWriter( HIVE_PARQUET_USE_LEGACY_DECIMAL_ENCODING, HIVE_PARQUET_USE_INT96_TIMESTAMP_ENCODING); + Optional> validationInputFactory = Optional.empty(); + if (isParquetOptimizedWriterValidate(session)) { + validationInputFactory = Optional.of(() -> { + try { + return new HdfsParquetDataSource( + new ParquetDataSourceId(path.toString()), + fileSystem.getFileStatus(path).getLen(), + fileSystem.open(path), + readStats, + new ParquetReaderOptions()); + } + catch (IOException e) { + throw new TrinoException(HIVE_WRITE_VALIDATION_FAILED, e); + } + }); + } + return Optional.of(new ParquetFileWriter( fileSystem.create(path, false), rollbackAction, fileColumnTypes, + fileColumnNames, schemaConverter.getMessageType(), schemaConverter.getPrimitiveTypes(), parquetWriterOptions, fileInputColumnIndexes, compressionCodecName, nodeVersion.toString(), - Optional.of(parquetTimeZone))); + Optional.of(parquetTimeZone), + validationInputFactory)); } catch (IOException e) { throw new TrinoException(HIVE_WRITER_OPEN_ERROR, "Error creating Parquet file", e); } } + @Managed + @Flatten + public FileFormatDataSourceStats getReadStats() + { + return readStats; + } + private static CompressionCodecName getCompression(JobConf configuration) { String compressionName = configuration.get(ParquetOutputFormat.COMPRESSION); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java index de56ec42b302..3f1aed4ddd45 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java @@ -14,20 +14,17 @@ package io.trino.plugin.hive.parquet; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Streams; -import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; +import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.reader.ParquetReader; +import io.trino.parquet.reader.ParquetReaderColumn; import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; -import io.trino.spi.block.LazyBlockLoader; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.metrics.Metrics; -import io.trino.spi.type.Type; import java.io.IOException; import java.io.UncheckedIOException; @@ -35,69 +32,30 @@ import java.util.Optional; import java.util.OptionalLong; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static io.trino.plugin.base.util.Closables.closeAllSuppress; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; import static java.lang.String.format; -import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; public class ParquetPageSource implements 
ConnectorPageSource { private final ParquetReader parquetReader; - private final List types; - private final List> fields; - /** - * Indicates whether the column at each index should be populated with the - * indices of its rows - */ - private final List rowIndexLocations; - - private int batchId; + private final List parquetReaderColumns; + private final boolean areSyntheticColumnsPresent; + private boolean closed; private long completedPositions; - public ParquetPageSource(ParquetReader parquetReader, List types, List> fields) - { - this(parquetReader, types, nCopies(types.size(), false), fields); - } - - /** - * @param types Column types - * @param rowIndexLocations Whether each column should be populated with the indices of its rows - * @param fields List of field descriptions. Empty optionals will result in columns populated with {@code NULL} - */ public ParquetPageSource( ParquetReader parquetReader, - List types, - List rowIndexLocations, - List> fields) + List parquetReaderColumns) { this.parquetReader = requireNonNull(parquetReader, "parquetReader is null"); - this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); - this.rowIndexLocations = requireNonNull(rowIndexLocations, "rowIndexLocations is null"); - this.fields = ImmutableList.copyOf(requireNonNull(fields, "fields is null")); - - // TODO: Instead of checking that the three list arguments go together correctly, - // we should do something like the ORC reader's ColumnAdatpation, using - // subclasses that contain only the necessary information for each column. - checkArgument( - types.size() == rowIndexLocations.size() && types.size() == fields.size(), - "types, rowIndexLocations, and fields must correspond one-to-one-to-one"); - Streams.forEachPair( - rowIndexLocations.stream(), - fields.stream(), - (isIndexColumn, field) -> checkArgument( - !(isIndexColumn && field.isPresent()), - "Field info for row index column must be empty Optional")); - } - - private boolean isIndexColumn(int column) - { - return rowIndexLocations.get(column); + this.parquetReaderColumns = ImmutableList.copyOf(requireNonNull(parquetReaderColumns, "parquetReaderColumns is null")); + this.areSyntheticColumnsPresent = parquetReaderColumns.stream() + .anyMatch(column -> column.isRowIndexColumn() || column.field().isEmpty()); } @Override @@ -133,39 +91,22 @@ public long getMemoryUsage() @Override public Page getNextPage() { + Page page; try { - batchId++; - int batchSize = parquetReader.nextBatch(); - - if (closed || batchSize <= 0) { - close(); - return null; - } - - completedPositions += batchSize; - - Block[] blocks = new Block[fields.size()]; - for (int column = 0; column < blocks.length; column++) { - if (isIndexColumn(column)) { - blocks[column] = getRowIndexColumn(parquetReader.lastBatchStartRow(), batchSize); - } - else { - Type type = types.get(column); - blocks[column] = fields.get(column) - .map(field -> new LazyBlock(batchSize, new ParquetBlockLoader(field))) - .orElseGet(() -> RunLengthEncodedBlock.create(type, null, batchSize)); - } - } - return new Page(batchSize, blocks); + page = getColumnAdaptationsPage(parquetReader.nextPage()); } - catch (TrinoException e) { + catch (IOException | RuntimeException e) { closeAllSuppress(e, this); - throw e; + throw handleException(parquetReader.getDataSource().getId(), e); } - catch (RuntimeException e) { - closeAllSuppress(e, this); - throw new TrinoException(HIVE_CURSOR_ERROR, e); + + if (closed || page == null) { + close(); + return null; } + + completedPositions += 
page.getPositionCount(); + return page; } @Override @@ -184,43 +125,48 @@ public void close() } } - private final class ParquetBlockLoader - implements LazyBlockLoader + @Override + public Metrics getMetrics() { - /** - * Stores batch ID at instantiation time. Loading fails if the ID - * changes before {@link #load()} is called. - */ - private final int expectedBatchId = batchId; - private final Field field; - private boolean loaded; - - public ParquetBlockLoader(Field field) - { - this.field = requireNonNull(field, "field is null"); - } - - @Override - public Block load() - { - checkState(!loaded, "Already loaded"); - checkState(batchId == expectedBatchId, "Inconsistent state; wrong batch"); + return new Metrics(parquetReader.getCodecMetrics()); + } - Block block; - String parquetDataSourceId = parquetReader.getDataSource().getId().toString(); - try { - block = parquetReader.readBlock(field); + private Page getColumnAdaptationsPage(Page page) + { + if (!areSyntheticColumnsPresent) { + return page; + } + if (page == null) { + return null; + } + int batchSize = page.getPositionCount(); + Block[] blocks = new Block[parquetReaderColumns.size()]; + int sourceColumn = 0; + for (int columnIndex = 0; columnIndex < parquetReaderColumns.size(); columnIndex++) { + ParquetReaderColumn column = parquetReaderColumns.get(columnIndex); + if (column.isRowIndexColumn()) { + blocks[columnIndex] = getRowIndexColumn(parquetReader.lastBatchStartRow(), batchSize); } - catch (ParquetCorruptionException e) { - throw new TrinoException(HIVE_BAD_DATA, format("Corrupted parquet data; source=%s; %s", parquetDataSourceId, e.getMessage()), e); + else if (column.field().isEmpty()) { + blocks[columnIndex] = RunLengthEncodedBlock.create(column.type(), null, batchSize); } - catch (IOException e) { - throw new TrinoException(HIVE_CURSOR_ERROR, format("Failed reading parquet data; source= %s; %s", parquetDataSourceId, e.getMessage()), e); + else { + blocks[columnIndex] = page.getBlock(sourceColumn); + sourceColumn++; } + } + return new Page(batchSize, blocks); + } - loaded = true; - return block; + static TrinoException handleException(ParquetDataSourceId dataSourceId, Exception exception) + { + if (exception instanceof TrinoException) { + return (TrinoException) exception; + } + if (exception instanceof ParquetCorruptionException) { + return new TrinoException(HIVE_BAD_DATA, exception); } + return new TrinoException(HIVE_CURSOR_ERROR, format("Failed to read Parquet file: %s", dataSourceId), exception); } private static Block getRowIndexColumn(long baseIndex, int size) @@ -231,10 +177,4 @@ private static Block getRowIndexColumn(long baseIndex, int size) } return new LongArrayBlock(size, Optional.empty(), rowIndices); } - - @Override - public Metrics getMetrics() - { - return new Metrics(parquetReader.getCodecMetrics()); - } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index 13a26ad990f5..37eb876d144f 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -19,13 +19,15 @@ import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.TrinoFileSystemFactory; import io.trino.filesystem.TrinoInputFile; -import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; import 
io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.MetadataReader; import io.trino.parquet.reader.ParquetReader; +import io.trino.parquet.reader.ParquetReaderColumn; import io.trino.parquet.reader.TrinoColumnIndexStore; import io.trino.plugin.hive.AcidInfo; import io.trino.plugin.hive.FileFormatDataSourceStats; @@ -41,7 +43,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; -import io.trino.spi.type.Type; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hdfs.BlockMissingException; @@ -72,14 +73,17 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.trino.parquet.ParquetTypeUtils.constructField; import static io.trino.parquet.ParquetTypeUtils.getColumnIO; import static io.trino.parquet.ParquetTypeUtils.getDescriptors; import static io.trino.parquet.ParquetTypeUtils.getParquetTypeByName; import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; import static io.trino.parquet.predicate.PredicateUtils.buildPredicate; import static io.trino.parquet.predicate.PredicateUtils.predicateMatches; +import static io.trino.parquet.reader.ParquetReaderColumn.getParquetReaderFields; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; @@ -90,7 +94,7 @@ import static io.trino.plugin.hive.HiveSessionProperties.isParquetIgnoreStatistics; import static io.trino.plugin.hive.HiveSessionProperties.isParquetUseColumnIndex; import static io.trino.plugin.hive.HiveSessionProperties.isUseParquetColumnNames; -import static io.trino.plugin.hive.parquet.HiveParquetColumnIOConverter.constructField; +import static io.trino.plugin.hive.parquet.ParquetPageSource.handleException; import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName; import static io.trino.spi.type.BigintType.BIGINT; import static java.lang.String.format; @@ -174,7 +178,8 @@ public Optional createPageSource( stats, options.withIgnoreStatistics(isParquetIgnoreStatistics(session)) .withMaxReadBlockSize(getParquetMaxReadBlockSize(session)) - .withUseColumnIndex(isParquetUseColumnIndex(session)))); + .withUseColumnIndex(isParquetUseColumnIndex(session)), + Optional.empty())); } /** @@ -189,7 +194,8 @@ public static ReaderPageSource createPageSource( boolean useColumnNames, DateTimeZone timeZone, FileFormatDataSourceStats stats, - ParquetReaderOptions options) + ParquetReaderOptions options, + Optional parquetWriteValidation) { // Ignore predicates on partial columns for now. 
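// (Only domains on base columns are kept by the filter below; predicates referring to
// projected sub-fields are dropped before the Parquet predicate is built.)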
effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn()); @@ -197,12 +203,11 @@ public static ReaderPageSource createPageSource( MessageType fileSchema; MessageType requestedSchema; MessageColumnIO messageColumn; - ParquetReader parquetReader; ParquetDataSource dataSource = null; try { dataSource = new TrinoParquetDataSource(inputFile, options, stats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, parquetWriteValidation); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); fileSchema = fileMetaData.getSchema(); @@ -250,47 +255,24 @@ && predicateMatches(parquetPredicate, block, dataSource, descriptorsByPath, parq .map(HiveColumnHandle.class::cast) .collect(toUnmodifiableList())) .orElse(columns); + List parquetReaderColumns = createParquetReaderColumns(baseColumns, fileSchema, messageColumn, useColumnNames); - for (HiveColumnHandle column : baseColumns) { - checkArgument(column == PARQUET_ROW_INDEX_COLUMN || column.getColumnType() == REGULAR, "column type must be REGULAR: %s", column); - } - - ImmutableList.Builder trinoTypes = ImmutableList.builder(); - ImmutableList.Builder> internalFieldsBuilder = ImmutableList.builder(); - ImmutableList.Builder rowIndexColumns = ImmutableList.builder(); - for (HiveColumnHandle column : baseColumns) { - trinoTypes.add(column.getBaseType()); - rowIndexColumns.add(column == PARQUET_ROW_INDEX_COLUMN); - if (column == PARQUET_ROW_INDEX_COLUMN) { - internalFieldsBuilder.add(Optional.empty()); - } - else { - internalFieldsBuilder.add(Optional.ofNullable(getParquetType(column, fileSchema, useColumnNames)) - .flatMap(field -> { - String columnName = useColumnNames ? column.getBaseColumnName() : fileSchema.getFields().get(column.getBaseHiveColumnIndex()).getName(); - return constructField(column.getBaseType(), lookupColumnByName(messageColumn, columnName)); - })); - } - } - - List> internalFields = internalFieldsBuilder.build(); - parquetReader = new ParquetReader( - Optional.ofNullable(fileMetaData.getCreatedBy()), - internalFields, - blocks.build(), - blockStarts.build(), - dataSource, - timeZone, - newSimpleAggregatedMemoryContext(), - options, - Optional.of(parquetPredicate), - columnIndexes.build()); - + ParquetDataSourceId dataSourceId = dataSource.getId(); ConnectorPageSource parquetPageSource = new ParquetPageSource( - parquetReader, - trinoTypes.build(), - rowIndexColumns.build(), - internalFields); + new ParquetReader( + Optional.ofNullable(fileMetaData.getCreatedBy()), + getParquetReaderFields(parquetReaderColumns), + blocks.build(), + blockStarts.build(), + dataSource, + timeZone, + newSimpleAggregatedMemoryContext(), + options, + exception -> handleException(dataSourceId, exception), + Optional.of(parquetPredicate), + columnIndexes.build(), + parquetWriteValidation), + parquetReaderColumns); return new ReaderPageSource(parquetPageSource, readerProjections); } catch (Exception e) { @@ -445,4 +427,25 @@ private static org.apache.parquet.schema.Type getParquetType(HiveColumnHandle co } return null; } + + private static List createParquetReaderColumns(List baseColumns, MessageType fileSchema, MessageColumnIO messageColumn, boolean useColumnNames) + { + for (HiveColumnHandle column : baseColumns) { + checkArgument(column == PARQUET_ROW_INDEX_COLUMN || column.getColumnType() == REGULAR, "column type must be REGULAR: %s", column); + } + + return baseColumns.stream() + .map(column -> { + boolean isRowIndexColumn = 
column == PARQUET_ROW_INDEX_COLUMN; + return new ParquetReaderColumn( + column.getBaseType(), + isRowIndexColumn ? Optional.empty() : Optional.ofNullable(getParquetType(column, fileSchema, useColumnNames)) + .flatMap(field -> { + String columnName = useColumnNames ? column.getBaseColumnName() : fileSchema.getFields().get(column.getBaseHiveColumnIndex()).getName(); + return constructField(column.getBaseType(), lookupColumnByName(messageColumn, columnName)); + }), + isRowIndexColumn); + }) + .collect(toImmutableList()); + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java index 2415e8fa41f6..1b2f7653aa4e 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java @@ -20,6 +20,9 @@ import io.trino.parquet.writer.ParquetWriterOptions; import org.apache.parquet.hadoop.ParquetWriter; +import javax.validation.constraints.DecimalMax; +import javax.validation.constraints.DecimalMin; + public class ParquetWriterConfig { private boolean parquetOptimizedWriterEnabled; @@ -27,6 +30,7 @@ public class ParquetWriterConfig private DataSize blockSize = DataSize.ofBytes(ParquetWriter.DEFAULT_BLOCK_SIZE); private DataSize pageSize = DataSize.ofBytes(ParquetWriter.DEFAULT_PAGE_SIZE); private int batchSize = ParquetWriterOptions.DEFAULT_BATCH_SIZE; + private double validationPercentage = 5; public DataSize getBlockSize() { @@ -80,4 +84,19 @@ public int getBatchSize() { return batchSize; } + + @DecimalMin("0.0") + @DecimalMax("100.0") + public double getValidationPercentage() + { + return validationPercentage; + } + + @Config("parquet.optimized-writer.validation-percentage") + @ConfigDescription("Percentage of parquet files to validate after write by re-reading the whole file") + public ParquetWriterConfig setValidationPercentage(double validationPercentage) + { + this.validationPercentage = validationPercentage; + return this; + } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java index af479988f7be..8fdaf598cf1b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java @@ -443,7 +443,11 @@ public void testParquetPageSourceGzip(int rowCount, long fileSizePadding) public void testOptimizedParquetWriter(int rowCount) throws Exception { - ConnectorSession session = getHiveSession(new HiveConfig(), new ParquetWriterConfig().setParquetOptimizedWriterEnabled(true)); + ConnectorSession session = getHiveSession( + new HiveConfig(), + new ParquetWriterConfig() + .setParquetOptimizedWriterEnabled(true) + .setValidationPercentage(100.0)); assertTrue(HiveSessionProperties.isParquetOptimizedWriterEnabled(session)); List testColumns = getTestColumnsSupportedByParquet(); @@ -451,7 +455,7 @@ public void testOptimizedParquetWriter(int rowCount) .withSession(session) .withColumns(testColumns) .withRowsCount(rowCount) - .withFileWriterFactory(new ParquetFileWriterFactory(HDFS_ENVIRONMENT, new NodeVersion("test-version"), TESTING_TYPE_MANAGER, new HiveConfig())) + .withFileWriterFactory(new ParquetFileWriterFactory(HDFS_ENVIRONMENT, new NodeVersion("test-version"), TESTING_TYPE_MANAGER, new HiveConfig(), STATS)) 
.isReadableByPageSource(PARQUET_PAGE_SOURCE_FACTORY); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java index 6463abcbceca..a0ccde2547b2 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java @@ -319,7 +319,8 @@ public PrestoParquetFormatWriter(File targetFile, List columnNames, List ParquetWriterOptions.builder().build(), compressionCodec.getParquetCompressionCodec(), "test-version", - Optional.of(DateTimeZone.getDefault())); + Optional.of(DateTimeZone.getDefault()), + Optional.empty()); } @Override diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java index f2967164076c..566fc1d08f6a 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java @@ -22,9 +22,12 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.units.DataSize; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.writer.ParquetSchemaConverter; import io.trino.parquet.writer.ParquetWriter; import io.trino.parquet.writer.ParquetWriterOptions; +import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HiveSessionProperties; import io.trino.plugin.hive.HiveStorageFormat; @@ -38,6 +41,7 @@ import io.trino.plugin.hive.parquet.write.TestMapredParquetOutputFormat; import io.trino.spi.Page; import io.trino.spi.PageBuilder; +import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorPageSource; @@ -59,6 +63,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.testing.TestingConnectorSession; +import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter; import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe; @@ -101,9 +106,11 @@ import static com.google.common.collect.Iterables.transform; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.hadoop.ConfigurationInstantiator.newEmptyConfiguration; +import static io.trino.parquet.ParquetWriteValidation.ParquetWriteValidationBuilder; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_INT96_TIMESTAMP_ENCODING; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_LEGACY_DECIMAL_ENCODING; import static io.trino.plugin.hive.AbstractTestHiveFileFormats.getFieldFromCursor; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITE_VALIDATION_FAILED; import static io.trino.plugin.hive.HiveSessionProperties.getParquetMaxReadBlockSize; import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT; import static io.trino.plugin.hive.HiveTestUtils.getHiveSession; @@ -751,7 +758,8 @@ private static void writeParquetColumnTrino( .build(), compressionCodecName, "test-version", - Optional.of(DateTimeZone.getDefault())); + Optional.of(DateTimeZone.getDefault()), + Optional.of(new ParquetWriteValidationBuilder(types, columnNames))); 
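// The writer above is constructed with a ParquetWriteValidationBuilder, so after the test pages
// are written and the writer is closed, the file is re-read through an HdfsParquetDataSource and
// checked against the recorded schema, row groups and column statistics via writer.validate(...).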
PageBuilder pageBuilder = new PageBuilder(types); for (int i = 0; i < types.size(); ++i) { @@ -768,6 +776,19 @@ private static void writeParquetColumnTrino( pageBuilder.declarePositions(size); writer.write(pageBuilder.build()); writer.close(); + Path path = new Path(outputFile.getPath()); + FileSystem fileSystem = HDFS_ENVIRONMENT.getFileSystem(SESSION.getIdentity(), path, newEmptyConfiguration()); + try { + writer.validate(new HdfsParquetDataSource( + new ParquetDataSourceId(path.toString()), + fileSystem.getFileStatus(path).getLen(), + fileSystem.open(path), + new FileFormatDataSourceStats(), + new ParquetReaderOptions())); + } + catch (IOException e) { + throw new TrinoException(HIVE_WRITE_VALIDATION_FAILED, e); + } } private static void writeValue(Type type, BlockBuilder blockBuilder, Object value) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java index ce1f83ab3ea4..c72fca4a075e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java @@ -35,7 +35,8 @@ public void testDefaults() .setParquetOptimizedWriterEnabled(false) .setBlockSize(DataSize.ofBytes(ParquetWriter.DEFAULT_BLOCK_SIZE)) .setPageSize(DataSize.ofBytes(ParquetWriter.DEFAULT_PAGE_SIZE)) - .setBatchSize(ParquetWriterOptions.DEFAULT_BATCH_SIZE)); + .setBatchSize(ParquetWriterOptions.DEFAULT_BATCH_SIZE) + .setValidationPercentage(5)); } @Test @@ -60,13 +61,15 @@ public void testExplicitPropertyMappings() "parquet.experimental-optimized-writer.enabled", "true", "parquet.writer.block-size", "234MB", "parquet.writer.page-size", "11MB", - "parquet.writer.batch-size", "100"); + "parquet.writer.batch-size", "100", + "parquet.optimized-writer.validation-percentage", "10"); ParquetWriterConfig expected = new ParquetWriterConfig() .setParquetOptimizedWriterEnabled(true) .setBlockSize(DataSize.of(234, MEGABYTE)) .setPageSize(DataSize.of(11, MEGABYTE)) - .setBatchSize(100); + .setBatchSize(100) + .setValidationPercentage(10); assertFullMapping(properties, expected); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java index bc0562c4c6ba..52c370983911 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java @@ -188,6 +188,7 @@ private IcebergFileWriter createParquetWriter( outputStream, rollbackAction, fileColumnTypes, + fileColumnNames, convert(icebergSchema, "table"), makeTypeMap(fileColumnTypes, fileColumnNames), parquetWriterOptions, diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index 57370e422807..4ab1227a7cf3 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -35,13 +35,14 @@ import io.trino.orc.TupleDomainOrcPredicate; import io.trino.orc.TupleDomainOrcPredicate.TupleDomainOrcPredicateBuilder; import io.trino.orc.metadata.OrcType; -import io.trino.parquet.Field; import 
io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.MetadataReader; import io.trino.parquet.reader.ParquetReader; +import io.trino.parquet.reader.ParquetReaderColumn; import io.trino.plugin.base.classloader.ClassLoaderSafeUpdatablePageSource; import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.ReaderColumns; @@ -139,6 +140,7 @@ import static io.trino.parquet.ParquetTypeUtils.getDescriptors; import static io.trino.parquet.predicate.PredicateUtils.buildPredicate; import static io.trino.parquet.predicate.PredicateUtils.predicateMatches; +import static io.trino.parquet.reader.ParquetReaderColumn.getParquetReaderFields; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_DATA; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_SPEC_ID; import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_BAD_DATA; @@ -928,7 +930,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( ParquetDataSource dataSource = null; try { dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); if (nameMapping.isPresent() && !ParquetSchemaUtil.hasIds(fileSchema)) { @@ -979,9 +981,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( ConstantPopulatingPageSource.Builder constantPopulatingPageSourceBuilder = ConstantPopulatingPageSource.builder(); int parquetSourceChannel = 0; - ImmutableList.Builder trinoTypes = ImmutableList.builder(); - ImmutableList.Builder> internalFieldsBuilder = ImmutableList.builder(); - ImmutableList.Builder rowIndexChannels = ImmutableList.builder(); + ImmutableList.Builder parquetReaderColumnBuilder = ImmutableList.builder(); for (int columnIndex = 0; columnIndex < readColumns.size(); columnIndex++) { IcebergColumnHandle column = readColumns.get(columnIndex); if (column.isIsDeletedColumn()) { @@ -1001,16 +1001,12 @@ else if (column.isFileModifiedTimeColumn()) { } else if (column.isUpdateRowIdColumn() || column.isMergeRowIdColumn()) { // $row_id is a composite of multiple physical columns, it is assembled by the IcebergPageSource - trinoTypes.add(column.getType()); - internalFieldsBuilder.add(Optional.empty()); - rowIndexChannels.add(false); + parquetReaderColumnBuilder.add(new ParquetReaderColumn(column.getType(), Optional.empty(), false)); constantPopulatingPageSourceBuilder.addDelegateColumn(parquetSourceChannel); parquetSourceChannel++; } else if (column.isRowPositionColumn()) { - trinoTypes.add(BIGINT); - internalFieldsBuilder.add(Optional.empty()); - rowIndexChannels.add(true); + parquetReaderColumnBuilder.add(new ParquetReaderColumn(BIGINT, Optional.empty(), true)); constantPopulatingPageSourceBuilder.addDelegateColumn(parquetSourceChannel); parquetSourceChannel++; } @@ -1021,18 +1017,19 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), utf8Slice(partitionData))); } else { - rowIndexChannels.add(false); org.apache.parquet.schema.Type 
parquetField = parquetFields.get(columnIndex); Type trinoType = column.getBaseType(); - trinoTypes.add(trinoType); if (parquetField == null) { - internalFieldsBuilder.add(Optional.empty()); + parquetReaderColumnBuilder.add(new ParquetReaderColumn(trinoType, Optional.empty(), false)); } else { // The top level columns are already mapped by name/id appropriately. ColumnIO columnIO = messageColumnIO.getChild(parquetField.getName()); - internalFieldsBuilder.add(IcebergParquetColumnIOConverter.constructField(new FieldContext(trinoType, column.getColumnIdentity()), columnIO)); + parquetReaderColumnBuilder.add(new ParquetReaderColumn( + trinoType, + IcebergParquetColumnIOConverter.constructField(new FieldContext(trinoType, column.getColumnIdentity()), columnIO), + false)); } constantPopulatingPageSourceBuilder.addDelegateColumn(parquetSourceChannel); @@ -1040,21 +1037,21 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { } } - List> internalFields = internalFieldsBuilder.build(); + List parquetReaderColumns = parquetReaderColumnBuilder.build(); + ParquetDataSourceId dataSourceId = dataSource.getId(); ParquetReader parquetReader = new ParquetReader( Optional.ofNullable(fileMetaData.getCreatedBy()), - internalFields, + getParquetReaderFields(parquetReaderColumns), blocks, blockStarts.build(), dataSource, UTC, memoryContext, options, - Optional.empty()); - + exception -> handleException(dataSourceId, exception)); return new ReaderPageSourceWithRowPositions( new ReaderPageSource( - constantPopulatingPageSourceBuilder.build(new ParquetPageSource(parquetReader, trinoTypes.build(), rowIndexChannels.build(), internalFields)), + constantPopulatingPageSourceBuilder.build(new ParquetPageSource(parquetReader, parquetReaderColumns)), columnProjections), startRowPosition, endRowPosition); @@ -1353,6 +1350,17 @@ private static TrinoException handleException(OrcDataSourceId dataSourceId, Exce return new TrinoException(ICEBERG_CURSOR_ERROR, format("Failed to read ORC file: %s", dataSourceId), exception); } + private static TrinoException handleException(ParquetDataSourceId dataSourceId, Exception exception) + { + if (exception instanceof TrinoException) { + return (TrinoException) exception; + } + if (exception instanceof ParquetCorruptionException) { + return new TrinoException(ICEBERG_BAD_DATA, exception); + } + return new TrinoException(ICEBERG_CURSOR_ERROR, format("Failed to read Parquet file: %s", dataSourceId), exception); + } + public static final class ReaderPageSourceWithRowPositions { private final ReaderPageSource readerPageSource; diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java index 446fe586a2be..c4d7ec0e34d2 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java @@ -45,6 +45,7 @@ public IcebergParquetFileWriter( OutputStream outputStream, Callable rollbackAction, List fileColumnTypes, + List fileColumnNames, MessageType messageType, Map, Type> primitiveTypes, ParquetWriterOptions parquetWriterOptions, @@ -57,12 +58,14 @@ public IcebergParquetFileWriter( super(outputStream, rollbackAction, fileColumnTypes, + fileColumnNames, messageType, primitiveTypes, parquetWriterOptions, fileInputColumnIndexes, compressionCodecName, trinoVersion, + Optional.empty(), Optional.empty()); this.metricsConfig = 
requireNonNull(metricsConfig, "metricsConfig is null");
        this.outputPath = requireNonNull(outputPath, "outputPath is null");
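
For reference, a minimal sketch of the write-then-validate flow introduced by this change, assuming the ParquetWriter, ParquetWriteValidationBuilder and HdfsParquetDataSource signatures shown in the hunks above; the lower-case variable names (outputStream, messageType, primitiveTypes, writerOptions, types, columnNames, page, fileSystem, path, stats, parquetTimeZone) are placeholders and checked-exception handling is omitted:

// Enable validation by passing a ParquetWriteValidationBuilder to the writer.
ParquetWriter writer = new ParquetWriter(
        outputStream,
        messageType,
        primitiveTypes,
        writerOptions,
        compressionCodecName,
        trinoVersion,
        Optional.of(parquetTimeZone),
        Optional.of(new ParquetWriteValidationBuilder(types, columnNames)));
writer.write(page);
writer.close();

// After closing, re-read the file and compare it against the state recorded during the write.
try (ParquetDataSource input = new HdfsParquetDataSource(
        new ParquetDataSourceId(path.toString()),
        fileSystem.getFileStatus(path).getLen(),
        fileSystem.open(path),
        stats,
        new ParquetReaderOptions())) {
    writer.validate(input); // throws ParquetCorruptionException if the file does not match
}

In the Hive connector this flow runs only for a sampled fraction of written files: isParquetOptimizedWriterValidate(session) draws ThreadLocalRandom.current().nextDouble(100) and compares it against the configured validation percentage, so a percentage of 0 disables validation and 100 validates every written file.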