From a6922ac50fb446fa5dcdef59c88f06859eb96e62 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Fri, 5 Aug 2022 13:57:56 +0530 Subject: [PATCH 1/4] Remove unused ParquetPageSource constructor --- .../io/trino/plugin/hive/parquet/ParquetPageSource.java | 6 ------ 1 file changed, 6 deletions(-) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java index de56ec42b302..df4a02ee7444 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java @@ -41,7 +41,6 @@ import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; import static java.lang.String.format; -import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; public class ParquetPageSource @@ -60,11 +59,6 @@ public class ParquetPageSource private boolean closed; private long completedPositions; - public ParquetPageSource(ParquetReader parquetReader, List types, List> fields) - { - this(parquetReader, types, nCopies(types.size(), false), fields); - } - /** * @param types Column types * @param rowIndexLocations Whether each column should be populated with the indices of its rows From 0f832dbd21309f42f512f7c618518f373e89c081 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Fri, 5 Aug 2022 15:23:24 +0530 Subject: [PATCH 2/4] Allow creation of pages from ParquetReader Added ParquetBlockFactory along similar lines as OrcBlockFactory to handle creation of lazy blocks and addition of connector specific error codes to exceptions. This change makes it possible for the writer to perform validation without having to rely on ConnectorPageSource. --- .../parquet/reader/ParquetBlockFactory.java | 79 +++++++++ .../trino/parquet/reader/ParquetReader.java | 50 ++++-- .../parquet/reader/ParquetReaderColumn.java | 52 ++++++ .../hive/parquet/ParquetPageSource.java | 164 ++++++------------ .../parquet/ParquetPageSourceFactory.java | 83 +++++---- .../iceberg/IcebergPageSourceProvider.java | 46 +++-- 6 files changed, 289 insertions(+), 185 deletions(-) create mode 100644 lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetBlockFactory.java create mode 100644 lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReaderColumn.java diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetBlockFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetBlockFactory.java new file mode 100644 index 000000000000..241c3db01bd6 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetBlockFactory.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet.reader; + +import io.trino.spi.block.Block; +import io.trino.spi.block.LazyBlock; +import io.trino.spi.block.LazyBlockLoader; + +import java.io.IOException; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class ParquetBlockFactory +{ + private final Function exceptionTransform; + private int currentPageId; + + public ParquetBlockFactory(Function exceptionTransform) + { + this.exceptionTransform = requireNonNull(exceptionTransform, "exceptionTransform is null"); + } + + public void nextPage() + { + currentPageId++; + } + + public Block createBlock(int positionCount, ParquetBlockReader reader) + { + return new LazyBlock(positionCount, new ParquetBlockLoader(reader)); + } + + public interface ParquetBlockReader + { + Block readBlock() + throws IOException; + } + + private final class ParquetBlockLoader + implements LazyBlockLoader + { + private final int expectedPageId = currentPageId; + private final ParquetBlockReader blockReader; + private boolean loaded; + + public ParquetBlockLoader(ParquetBlockReader blockReader) + { + this.blockReader = requireNonNull(blockReader, "blockReader is null"); + } + + @Override + public Block load() + { + checkState(!loaded, "Already loaded"); + checkState(currentPageId == expectedPageId, "Parquet reader has been advanced beyond block"); + + loaded = true; + try { + return blockReader.readBlock(); + } + catch (IOException | RuntimeException e) { + throw exceptionTransform.apply(e); + } + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index 582fd4ca8d6e..b325fe508447 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -33,6 +33,7 @@ import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.FilteredOffsetIndex.OffsetRange; import io.trino.plugin.base.metrics.LongCount; +import io.trino.spi.Page; import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.RowBlock; @@ -67,6 +68,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -90,7 +92,8 @@ public class ParquetReader private final Optional fileCreatedBy; private final List blocks; private final List firstRowsOfBlocks; - private final List fields; + private final List columnFields; + private final List primitiveFields; private final ParquetDataSource dataSource; private final DateTimeZone timeZone; private final AggregatedMemoryContext memoryContext; @@ -121,39 +124,42 @@ public class ParquetReader private final List> columnIndexStore; private final List blockRowRanges; private final Map paths = new HashMap<>(); + private final ParquetBlockFactory blockFactory; private final Map> codecMetrics; public ParquetReader( Optional fileCreatedBy, - List> fields, + List columnFields, List blocks, List firstRowsOfBlocks, ParquetDataSource dataSource, DateTimeZone timeZone, AggregatedMemoryContext memoryContext, ParquetReaderOptions options, - Optional parquetPredicate) + Function exceptionTransform) throws IOException { - this(fileCreatedBy, fields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, parquetPredicate, 
nCopies(blocks.size(), Optional.empty())); + this(fileCreatedBy, columnFields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, exceptionTransform, Optional.empty(), nCopies(blocks.size(), Optional.empty())); } public ParquetReader( Optional fileCreatedBy, - List> fields, + List columnFields, List blocks, List firstRowsOfBlocks, ParquetDataSource dataSource, DateTimeZone timeZone, AggregatedMemoryContext memoryContext, ParquetReaderOptions options, + Function exceptionTransform, Optional parquetPredicate, List> columnIndexStore) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); - List primitiveFields = getPrimitiveFields(requireNonNull(fields, "fields is null")); - this.fields = primitiveFields; + requireNonNull(columnFields, "columnFields is null"); + this.columnFields = ImmutableList.copyOf(columnFields); + this.primitiveFields = getPrimitiveFields(columnFields); this.blocks = requireNonNull(blocks, "blocks is null"); this.firstRowsOfBlocks = requireNonNull(firstRowsOfBlocks, "firstRowsOfBlocks is null"); this.dataSource = requireNonNull(dataSource, "dataSource is null"); @@ -180,6 +186,7 @@ public ParquetReader( else { this.filter = Optional.empty(); } + this.blockFactory = new ParquetBlockFactory(exceptionTransform); ListMultimap ranges = ArrayListMultimap.create(); Map codecMetrics = new HashMap<>(); for (int rowGroup = 0; rowGroup < blocks.size(); rowGroup++) { @@ -226,6 +233,22 @@ public void close() dataSource.close(); } + public Page nextPage() + { + int batchSize = nextBatch(); + if (batchSize <= 0) { + return null; + } + // create a lazy page + blockFactory.nextPage(); + Block[] blocks = new Block[columnFields.size()]; + for (int channel = 0; channel < columnFields.size(); channel++) { + Field field = columnFields.get(channel); + blocks[channel] = blockFactory.createBlock(batchSize, () -> readBlock(field)); + } + return new Page(batchSize, blocks); + } + /** * Get the global row index of the first row in the last batch. 
*/ @@ -234,7 +257,7 @@ public long lastBatchStartRow() return firstRowIndexInGroup + nextRowInGroup - batchSize; } - public int nextBatch() + private int nextBatch() { if (nextRowInGroup >= currentGroupRowCount && !advanceToNextRowGroup()) { return -1; @@ -282,7 +305,7 @@ private void freeCurrentRowGroupBuffers() return; } - for (int column = 0; column < fields.size(); column++) { + for (int column = 0; column < primitiveFields.size(); column++) { Collection readers = chunkReaders.get(new ChunkKey(column, currentRowGroup)); if (readers != null) { for (ChunkReader reader : readers) { @@ -433,18 +456,15 @@ private ColumnChunkMetaData getColumnChunkMetaData(BlockMetaData blockMetaData, private void initializeColumnReaders() { - for (PrimitiveField field : fields) { + for (PrimitiveField field : primitiveFields) { columnReaders.put(field.getId(), PrimitiveColumnReader.createReader(field, timeZone)); } } - public static List getPrimitiveFields(List> fields) + public static List getPrimitiveFields(List fields) { Map primitiveFields = new HashMap<>(); - - fields.stream() - .flatMap(Optional::stream) - .forEach(field -> parseField(field, primitiveFields)); + fields.forEach(field -> parseField(field, primitiveFields)); return ImmutableList.copyOf(primitiveFields.values()); } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReaderColumn.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReaderColumn.java new file mode 100644 index 000000000000..a2060e273711 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReaderColumn.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.reader; + +import io.trino.parquet.Field; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * @param type Column type + * @param field Field description. 
Empty optional will result in column populated with {@code NULL} + * @param isRowIndexColumn Whether column should be populated with the indices of its rows + */ +public record ParquetReaderColumn(Type type, Optional field, boolean isRowIndexColumn) +{ + public static List getParquetReaderFields(List parquetReaderColumns) + { + return parquetReaderColumns.stream() + .filter(column -> !column.isRowIndexColumn()) + .map(ParquetReaderColumn::field) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableList()); + } + + public ParquetReaderColumn(Type type, Optional field, boolean isRowIndexColumn) + { + this.type = requireNonNull(type, "type is null"); + this.field = requireNonNull(field, "field is null"); + checkArgument( + !isRowIndexColumn || field.isEmpty(), + "Field info for row index column must be empty Optional"); + this.isRowIndexColumn = isRowIndexColumn; + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java index df4a02ee7444..f035b82ec8d7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java @@ -14,20 +14,17 @@ package io.trino.plugin.hive.parquet; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Streams; -import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; +import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.reader.ParquetReader; +import io.trino.parquet.reader.ParquetReaderColumn; import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; -import io.trino.spi.block.LazyBlockLoader; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.metrics.Metrics; -import io.trino.spi.type.Type; import java.io.IOException; import java.io.UncheckedIOException; @@ -35,8 +32,6 @@ import java.util.Optional; import java.util.OptionalLong; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static io.trino.plugin.base.util.Closables.closeAllSuppress; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; @@ -47,51 +42,20 @@ public class ParquetPageSource implements ConnectorPageSource { private final ParquetReader parquetReader; - private final List types; - private final List> fields; - /** - * Indicates whether the column at each index should be populated with the - * indices of its rows - */ - private final List rowIndexLocations; - - private int batchId; + private final List parquetReaderColumns; + private final boolean areSyntheticColumnsPresent; + private boolean closed; private long completedPositions; - /** - * @param types Column types - * @param rowIndexLocations Whether each column should be populated with the indices of its rows - * @param fields List of field descriptions. 
Empty optionals will result in columns populated with {@code NULL} - */ public ParquetPageSource( ParquetReader parquetReader, - List types, - List rowIndexLocations, - List> fields) + List parquetReaderColumns) { this.parquetReader = requireNonNull(parquetReader, "parquetReader is null"); - this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); - this.rowIndexLocations = requireNonNull(rowIndexLocations, "rowIndexLocations is null"); - this.fields = ImmutableList.copyOf(requireNonNull(fields, "fields is null")); - - // TODO: Instead of checking that the three list arguments go together correctly, - // we should do something like the ORC reader's ColumnAdatpation, using - // subclasses that contain only the necessary information for each column. - checkArgument( - types.size() == rowIndexLocations.size() && types.size() == fields.size(), - "types, rowIndexLocations, and fields must correspond one-to-one-to-one"); - Streams.forEachPair( - rowIndexLocations.stream(), - fields.stream(), - (isIndexColumn, field) -> checkArgument( - !(isIndexColumn && field.isPresent()), - "Field info for row index column must be empty Optional")); - } - - private boolean isIndexColumn(int column) - { - return rowIndexLocations.get(column); + this.parquetReaderColumns = ImmutableList.copyOf(requireNonNull(parquetReaderColumns, "parquetReaderColumns is null")); + this.areSyntheticColumnsPresent = parquetReaderColumns.stream() + .anyMatch(column -> column.isRowIndexColumn() || column.field().isEmpty()); } @Override @@ -127,39 +91,22 @@ public long getMemoryUsage() @Override public Page getNextPage() { + Page page; try { - batchId++; - int batchSize = parquetReader.nextBatch(); - - if (closed || batchSize <= 0) { - close(); - return null; - } - - completedPositions += batchSize; - - Block[] blocks = new Block[fields.size()]; - for (int column = 0; column < blocks.length; column++) { - if (isIndexColumn(column)) { - blocks[column] = getRowIndexColumn(parquetReader.lastBatchStartRow(), batchSize); - } - else { - Type type = types.get(column); - blocks[column] = fields.get(column) - .map(field -> new LazyBlock(batchSize, new ParquetBlockLoader(field))) - .orElseGet(() -> RunLengthEncodedBlock.create(type, null, batchSize)); - } - } - return new Page(batchSize, blocks); - } - catch (TrinoException e) { - closeAllSuppress(e, this); - throw e; + page = getColumnAdaptationsPage(parquetReader.nextPage()); } catch (RuntimeException e) { closeAllSuppress(e, this); - throw new TrinoException(HIVE_CURSOR_ERROR, e); + throw handleException(parquetReader.getDataSource().getId(), e); + } + + if (closed || page == null) { + close(); + return null; } + + completedPositions += page.getPositionCount(); + return page; } @Override @@ -178,43 +125,48 @@ public void close() } } - private final class ParquetBlockLoader - implements LazyBlockLoader + @Override + public Metrics getMetrics() { - /** - * Stores batch ID at instantiation time. Loading fails if the ID - * changes before {@link #load()} is called. 
- */ - private final int expectedBatchId = batchId; - private final Field field; - private boolean loaded; - - public ParquetBlockLoader(Field field) - { - this.field = requireNonNull(field, "field is null"); - } - - @Override - public Block load() - { - checkState(!loaded, "Already loaded"); - checkState(batchId == expectedBatchId, "Inconsistent state; wrong batch"); + return new Metrics(parquetReader.getCodecMetrics()); + } - Block block; - String parquetDataSourceId = parquetReader.getDataSource().getId().toString(); - try { - block = parquetReader.readBlock(field); + private Page getColumnAdaptationsPage(Page page) + { + if (!areSyntheticColumnsPresent) { + return page; + } + if (page == null) { + return null; + } + int batchSize = page.getPositionCount(); + Block[] blocks = new Block[parquetReaderColumns.size()]; + int sourceColumn = 0; + for (int columnIndex = 0; columnIndex < parquetReaderColumns.size(); columnIndex++) { + ParquetReaderColumn column = parquetReaderColumns.get(columnIndex); + if (column.isRowIndexColumn()) { + blocks[columnIndex] = getRowIndexColumn(parquetReader.lastBatchStartRow(), batchSize); } - catch (ParquetCorruptionException e) { - throw new TrinoException(HIVE_BAD_DATA, format("Corrupted parquet data; source=%s; %s", parquetDataSourceId, e.getMessage()), e); + else if (column.field().isEmpty()) { + blocks[columnIndex] = RunLengthEncodedBlock.create(column.type(), null, batchSize); } - catch (IOException e) { - throw new TrinoException(HIVE_CURSOR_ERROR, format("Failed reading parquet data; source= %s; %s", parquetDataSourceId, e.getMessage()), e); + else { + blocks[columnIndex] = page.getBlock(sourceColumn); + sourceColumn++; } + } + return new Page(batchSize, blocks); + } - loaded = true; - return block; + static TrinoException handleException(ParquetDataSourceId dataSourceId, Exception exception) + { + if (exception instanceof TrinoException) { + return (TrinoException) exception; + } + if (exception instanceof ParquetCorruptionException) { + return new TrinoException(HIVE_BAD_DATA, exception); } + return new TrinoException(HIVE_CURSOR_ERROR, format("Failed to read Parquet file: %s", dataSourceId), exception); } private static Block getRowIndexColumn(long baseIndex, int size) @@ -225,10 +177,4 @@ private static Block getRowIndexColumn(long baseIndex, int size) } return new LongArrayBlock(size, Optional.empty(), rowIndices); } - - @Override - public Metrics getMetrics() - { - return new Metrics(parquetReader.getCodecMetrics()); - } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index 13a26ad990f5..06b954c279b4 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -19,13 +19,14 @@ import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.TrinoFileSystemFactory; import io.trino.filesystem.TrinoInputFile; -import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.MetadataReader; import io.trino.parquet.reader.ParquetReader; +import io.trino.parquet.reader.ParquetReaderColumn; import 
io.trino.parquet.reader.TrinoColumnIndexStore; import io.trino.plugin.hive.AcidInfo; import io.trino.plugin.hive.FileFormatDataSourceStats; @@ -41,7 +42,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; -import io.trino.spi.type.Type; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hdfs.BlockMissingException; @@ -72,6 +72,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.parquet.ParquetTypeUtils.getColumnIO; @@ -80,6 +81,7 @@ import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; import static io.trino.parquet.predicate.PredicateUtils.buildPredicate; import static io.trino.parquet.predicate.PredicateUtils.predicateMatches; +import static io.trino.parquet.reader.ParquetReaderColumn.getParquetReaderFields; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; @@ -91,6 +93,7 @@ import static io.trino.plugin.hive.HiveSessionProperties.isParquetUseColumnIndex; import static io.trino.plugin.hive.HiveSessionProperties.isUseParquetColumnNames; import static io.trino.plugin.hive.parquet.HiveParquetColumnIOConverter.constructField; +import static io.trino.plugin.hive.parquet.ParquetPageSource.handleException; import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName; import static io.trino.spi.type.BigintType.BIGINT; import static java.lang.String.format; @@ -197,7 +200,6 @@ public static ReaderPageSource createPageSource( MessageType fileSchema; MessageType requestedSchema; MessageColumnIO messageColumn; - ParquetReader parquetReader; ParquetDataSource dataSource = null; try { dataSource = new TrinoParquetDataSource(inputFile, options, stats); @@ -250,47 +252,23 @@ && predicateMatches(parquetPredicate, block, dataSource, descriptorsByPath, parq .map(HiveColumnHandle.class::cast) .collect(toUnmodifiableList())) .orElse(columns); + List parquetReaderColumns = createParquetReaderColumns(baseColumns, fileSchema, messageColumn, useColumnNames); - for (HiveColumnHandle column : baseColumns) { - checkArgument(column == PARQUET_ROW_INDEX_COLUMN || column.getColumnType() == REGULAR, "column type must be REGULAR: %s", column); - } - - ImmutableList.Builder trinoTypes = ImmutableList.builder(); - ImmutableList.Builder> internalFieldsBuilder = ImmutableList.builder(); - ImmutableList.Builder rowIndexColumns = ImmutableList.builder(); - for (HiveColumnHandle column : baseColumns) { - trinoTypes.add(column.getBaseType()); - rowIndexColumns.add(column == PARQUET_ROW_INDEX_COLUMN); - if (column == PARQUET_ROW_INDEX_COLUMN) { - internalFieldsBuilder.add(Optional.empty()); - } - else { - internalFieldsBuilder.add(Optional.ofNullable(getParquetType(column, fileSchema, useColumnNames)) - .flatMap(field -> { - String columnName = useColumnNames ? 
column.getBaseColumnName() : fileSchema.getFields().get(column.getBaseHiveColumnIndex()).getName(); - return constructField(column.getBaseType(), lookupColumnByName(messageColumn, columnName)); - })); - } - } - - List> internalFields = internalFieldsBuilder.build(); - parquetReader = new ParquetReader( - Optional.ofNullable(fileMetaData.getCreatedBy()), - internalFields, - blocks.build(), - blockStarts.build(), - dataSource, - timeZone, - newSimpleAggregatedMemoryContext(), - options, - Optional.of(parquetPredicate), - columnIndexes.build()); - + ParquetDataSourceId dataSourceId = dataSource.getId(); ConnectorPageSource parquetPageSource = new ParquetPageSource( - parquetReader, - trinoTypes.build(), - rowIndexColumns.build(), - internalFields); + new ParquetReader( + Optional.ofNullable(fileMetaData.getCreatedBy()), + getParquetReaderFields(parquetReaderColumns), + blocks.build(), + blockStarts.build(), + dataSource, + timeZone, + newSimpleAggregatedMemoryContext(), + options, + exception -> handleException(dataSourceId, exception), + Optional.of(parquetPredicate), + columnIndexes.build()), + parquetReaderColumns); return new ReaderPageSource(parquetPageSource, readerProjections); } catch (Exception e) { @@ -445,4 +423,25 @@ private static org.apache.parquet.schema.Type getParquetType(HiveColumnHandle co } return null; } + + private static List createParquetReaderColumns(List baseColumns, MessageType fileSchema, MessageColumnIO messageColumn, boolean useColumnNames) + { + for (HiveColumnHandle column : baseColumns) { + checkArgument(column == PARQUET_ROW_INDEX_COLUMN || column.getColumnType() == REGULAR, "column type must be REGULAR: %s", column); + } + + return baseColumns.stream() + .map(column -> { + boolean isRowIndexColumn = column == PARQUET_ROW_INDEX_COLUMN; + return new ParquetReaderColumn( + column.getBaseType(), + isRowIndexColumn ? Optional.empty() : Optional.ofNullable(getParquetType(column, fileSchema, useColumnNames)) + .flatMap(field -> { + String columnName = useColumnNames ? 
column.getBaseColumnName() : fileSchema.getFields().get(column.getBaseHiveColumnIndex()).getName(); + return constructField(column.getBaseType(), lookupColumnByName(messageColumn, columnName)); + }), + isRowIndexColumn); + }) + .collect(toImmutableList()); + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index 57370e422807..7fb9dd54e506 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -35,13 +35,14 @@ import io.trino.orc.TupleDomainOrcPredicate; import io.trino.orc.TupleDomainOrcPredicate.TupleDomainOrcPredicateBuilder; import io.trino.orc.metadata.OrcType; -import io.trino.parquet.Field; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.MetadataReader; import io.trino.parquet.reader.ParquetReader; +import io.trino.parquet.reader.ParquetReaderColumn; import io.trino.plugin.base.classloader.ClassLoaderSafeUpdatablePageSource; import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.ReaderColumns; @@ -139,6 +140,7 @@ import static io.trino.parquet.ParquetTypeUtils.getDescriptors; import static io.trino.parquet.predicate.PredicateUtils.buildPredicate; import static io.trino.parquet.predicate.PredicateUtils.predicateMatches; +import static io.trino.parquet.reader.ParquetReaderColumn.getParquetReaderFields; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_DATA; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_SPEC_ID; import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_BAD_DATA; @@ -979,9 +981,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( ConstantPopulatingPageSource.Builder constantPopulatingPageSourceBuilder = ConstantPopulatingPageSource.builder(); int parquetSourceChannel = 0; - ImmutableList.Builder trinoTypes = ImmutableList.builder(); - ImmutableList.Builder> internalFieldsBuilder = ImmutableList.builder(); - ImmutableList.Builder rowIndexChannels = ImmutableList.builder(); + ImmutableList.Builder parquetReaderColumnBuilder = ImmutableList.builder(); for (int columnIndex = 0; columnIndex < readColumns.size(); columnIndex++) { IcebergColumnHandle column = readColumns.get(columnIndex); if (column.isIsDeletedColumn()) { @@ -1001,16 +1001,12 @@ else if (column.isFileModifiedTimeColumn()) { } else if (column.isUpdateRowIdColumn() || column.isMergeRowIdColumn()) { // $row_id is a composite of multiple physical columns, it is assembled by the IcebergPageSource - trinoTypes.add(column.getType()); - internalFieldsBuilder.add(Optional.empty()); - rowIndexChannels.add(false); + parquetReaderColumnBuilder.add(new ParquetReaderColumn(column.getType(), Optional.empty(), false)); constantPopulatingPageSourceBuilder.addDelegateColumn(parquetSourceChannel); parquetSourceChannel++; } else if (column.isRowPositionColumn()) { - trinoTypes.add(BIGINT); - internalFieldsBuilder.add(Optional.empty()); - rowIndexChannels.add(true); + parquetReaderColumnBuilder.add(new ParquetReaderColumn(BIGINT, Optional.empty(), true)); 
constantPopulatingPageSourceBuilder.addDelegateColumn(parquetSourceChannel); parquetSourceChannel++; } @@ -1021,18 +1017,19 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), utf8Slice(partitionData))); } else { - rowIndexChannels.add(false); org.apache.parquet.schema.Type parquetField = parquetFields.get(columnIndex); Type trinoType = column.getBaseType(); - trinoTypes.add(trinoType); if (parquetField == null) { - internalFieldsBuilder.add(Optional.empty()); + parquetReaderColumnBuilder.add(new ParquetReaderColumn(trinoType, Optional.empty(), false)); } else { // The top level columns are already mapped by name/id appropriately. ColumnIO columnIO = messageColumnIO.getChild(parquetField.getName()); - internalFieldsBuilder.add(IcebergParquetColumnIOConverter.constructField(new FieldContext(trinoType, column.getColumnIdentity()), columnIO)); + parquetReaderColumnBuilder.add(new ParquetReaderColumn( + trinoType, + IcebergParquetColumnIOConverter.constructField(new FieldContext(trinoType, column.getColumnIdentity()), columnIO), + false)); } constantPopulatingPageSourceBuilder.addDelegateColumn(parquetSourceChannel); @@ -1040,21 +1037,21 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { } } - List> internalFields = internalFieldsBuilder.build(); + List parquetReaderColumns = parquetReaderColumnBuilder.build(); + ParquetDataSourceId dataSourceId = dataSource.getId(); ParquetReader parquetReader = new ParquetReader( Optional.ofNullable(fileMetaData.getCreatedBy()), - internalFields, + getParquetReaderFields(parquetReaderColumns), blocks, blockStarts.build(), dataSource, UTC, memoryContext, options, - Optional.empty()); - + exception -> handleException(dataSourceId, exception)); return new ReaderPageSourceWithRowPositions( new ReaderPageSource( - constantPopulatingPageSourceBuilder.build(new ParquetPageSource(parquetReader, trinoTypes.build(), rowIndexChannels.build(), internalFields)), + constantPopulatingPageSourceBuilder.build(new ParquetPageSource(parquetReader, parquetReaderColumns)), columnProjections), startRowPosition, endRowPosition); @@ -1353,6 +1350,17 @@ private static TrinoException handleException(OrcDataSourceId dataSourceId, Exce return new TrinoException(ICEBERG_CURSOR_ERROR, format("Failed to read ORC file: %s", dataSourceId), exception); } + private static TrinoException handleException(ParquetDataSourceId dataSourceId, Exception exception) + { + if (exception instanceof TrinoException) { + return (TrinoException) exception; + } + if (exception instanceof ParquetCorruptionException) { + return new TrinoException(ICEBERG_BAD_DATA, exception); + } + return new TrinoException(ICEBERG_CURSOR_ERROR, format("Failed to read Parquet file: %s", dataSourceId), exception); + } + public static final class ReaderPageSourceWithRowPositions { private final ReaderPageSource readerPageSource; From 7bb5cbac6ca9d8df7ee551754b0b66ad5c77206f Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Fri, 12 Aug 2022 12:42:30 +0530 Subject: [PATCH 3/4] Move HiveParquetColumnIOConverter#constructField to trino-parquet The logic here does not have dependency on hive and it is needed here to allow ParquetWriter to create ParquetReader for the verification of the written file. 
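As a usage sketch only (not part of this patch), the relocated helper can now be called from trino-parquet itself when the writer re-opens the file it just wrote; the local names fileSchema, trinoType and columnName are illustrative placeholders, while getColumnIO, lookupColumnByName and constructField are the existing ParquetTypeUtils entry points:

    import static io.trino.parquet.ParquetTypeUtils.constructField;
    import static io.trino.parquet.ParquetTypeUtils.getColumnIO;
    import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName;

    // Map a Trino type onto the column structure of the written file so a
    // ParquetReader can be created for read-back verification.
    MessageColumnIO messageColumn = getColumnIO(fileSchema, fileSchema);
    Optional<Field> field = constructField(trinoType, lookupColumnByName(messageColumn, columnName));
    // An empty Optional means the column is not present in the file and would be read back as NULL.
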
--- .../io/trino/parquet/ParquetTypeUtils.java | 55 +++++++++++ .../parquet/HiveParquetColumnIOConverter.java | 92 ------------------- .../parquet/ParquetPageSourceFactory.java | 2 +- 3 files changed, 56 insertions(+), 93 deletions(-) delete mode 100644 plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/HiveParquetColumnIOConverter.java diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java index 26fd692d1c73..7f82fd3f2be0 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java @@ -13,7 +13,12 @@ */ package io.trino.parquet; +import com.google.common.collect.ImmutableList; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.DecimalType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; import org.apache.parquet.io.ColumnIO; @@ -31,10 +36,14 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static org.apache.parquet.io.ColumnIOUtil.columnDefinitionLevel; +import static org.apache.parquet.io.ColumnIOUtil.columnRepetitionLevel; +import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; import static org.apache.parquet.schema.Type.Repetition.REPEATED; public final class ParquetTypeUtils @@ -259,4 +268,50 @@ public static long getShortDecimalValue(byte[] bytes, int startOffset, int lengt return value; } + + public static Optional constructField(Type type, ColumnIO columnIO) + { + if (columnIO == null) { + return Optional.empty(); + } + boolean required = columnIO.getType().getRepetition() != OPTIONAL; + int repetitionLevel = columnRepetitionLevel(columnIO); + int definitionLevel = columnDefinitionLevel(columnIO); + if (type instanceof RowType rowType) { + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + ImmutableList.Builder> fieldsBuilder = ImmutableList.builder(); + List fields = rowType.getFields(); + boolean structHasParameters = false; + for (RowType.Field rowField : fields) { + String name = rowField.getName().orElseThrow().toLowerCase(Locale.ENGLISH); + Optional field = constructField(rowField.getType(), lookupColumnByName(groupColumnIO, name)); + structHasParameters |= field.isPresent(); + fieldsBuilder.add(field); + } + if (structHasParameters) { + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, fieldsBuilder.build())); + } + return Optional.empty(); + } + if (type instanceof MapType mapType) { + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO); + if (keyValueColumnIO.getChildrenCount() != 2) { + return Optional.empty(); + } + Optional keyField = constructField(mapType.getKeyType(), keyValueColumnIO.getChild(0)); + Optional valueField = constructField(mapType.getValueType(), keyValueColumnIO.getChild(1)); + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(keyField, valueField))); + } + if (type instanceof ArrayType arrayType) { + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + if (groupColumnIO.getChildrenCount() != 1) { + return Optional.empty(); + } + 
Optional field = constructField(arrayType.getElementType(), getArrayElementColumn(groupColumnIO.getChild(0))); + return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field))); + } + PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO; + return Optional.of(new PrimitiveField(type, required, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId())); + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/HiveParquetColumnIOConverter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/HiveParquetColumnIOConverter.java deleted file mode 100644 index 7487639861a7..000000000000 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/HiveParquetColumnIOConverter.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.hive.parquet; - -import com.google.common.collect.ImmutableList; -import io.trino.parquet.Field; -import io.trino.parquet.GroupField; -import io.trino.parquet.PrimitiveField; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.MapType; -import io.trino.spi.type.RowType; -import io.trino.spi.type.Type; -import org.apache.parquet.io.ColumnIO; -import org.apache.parquet.io.GroupColumnIO; -import org.apache.parquet.io.PrimitiveColumnIO; - -import java.util.List; -import java.util.Locale; -import java.util.Optional; - -import static io.trino.parquet.ParquetTypeUtils.getArrayElementColumn; -import static io.trino.parquet.ParquetTypeUtils.getMapKeyValueColumn; -import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; -import static org.apache.parquet.io.ColumnIOUtil.columnDefinitionLevel; -import static org.apache.parquet.io.ColumnIOUtil.columnRepetitionLevel; -import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; - -public final class HiveParquetColumnIOConverter -{ - private HiveParquetColumnIOConverter() {} - - public static Optional constructField(Type type, ColumnIO columnIO) - { - if (columnIO == null) { - return Optional.empty(); - } - boolean required = columnIO.getType().getRepetition() != OPTIONAL; - int repetitionLevel = columnRepetitionLevel(columnIO); - int definitionLevel = columnDefinitionLevel(columnIO); - if (type instanceof RowType) { - RowType rowType = (RowType) type; - GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; - ImmutableList.Builder> fieldsBuilder = ImmutableList.builder(); - List fields = rowType.getFields(); - boolean structHasParameters = false; - for (int i = 0; i < fields.size(); i++) { - RowType.Field rowField = fields.get(i); - String name = rowField.getName().orElseThrow().toLowerCase(Locale.ENGLISH); - Optional field = constructField(rowField.getType(), lookupColumnByName(groupColumnIO, name)); - structHasParameters |= field.isPresent(); - fieldsBuilder.add(field); - } - if (structHasParameters) { - return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, fieldsBuilder.build())); - } - return 
Optional.empty(); - } - if (type instanceof MapType) { - MapType mapType = (MapType) type; - GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; - GroupColumnIO keyValueColumnIO = getMapKeyValueColumn(groupColumnIO); - if (keyValueColumnIO.getChildrenCount() != 2) { - return Optional.empty(); - } - Optional keyField = constructField(mapType.getKeyType(), keyValueColumnIO.getChild(0)); - Optional valueField = constructField(mapType.getValueType(), keyValueColumnIO.getChild(1)); - return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(keyField, valueField))); - } - if (type instanceof ArrayType) { - ArrayType arrayType = (ArrayType) type; - GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; - if (groupColumnIO.getChildrenCount() != 1) { - return Optional.empty(); - } - Optional field = constructField(arrayType.getElementType(), getArrayElementColumn(groupColumnIO.getChild(0))); - return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field))); - } - PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO; - return Optional.of(new PrimitiveField(type, required, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId())); - } -} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index 06b954c279b4..fcdf69b71b45 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -75,6 +75,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.trino.parquet.ParquetTypeUtils.constructField; import static io.trino.parquet.ParquetTypeUtils.getColumnIO; import static io.trino.parquet.ParquetTypeUtils.getDescriptors; import static io.trino.parquet.ParquetTypeUtils.getParquetTypeByName; @@ -92,7 +93,6 @@ import static io.trino.plugin.hive.HiveSessionProperties.isParquetIgnoreStatistics; import static io.trino.plugin.hive.HiveSessionProperties.isParquetUseColumnIndex; import static io.trino.plugin.hive.HiveSessionProperties.isUseParquetColumnNames; -import static io.trino.plugin.hive.parquet.HiveParquetColumnIOConverter.constructField; import static io.trino.plugin.hive.parquet.ParquetPageSource.handleException; import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName; import static io.trino.spi.type.BigintType.BIGINT; From a41d95fbfc1d58b9af064f69e49693f71d58efb9 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Mon, 4 Jul 2022 10:24:47 +0530 Subject: [PATCH 4/4] Implement verification for optimized parquet writer Implements verification of file footer, row count, nulls count and checksum of columns. Added a config parquet.optimized-writer.validation-percentage and session property in hive connector to control the percentage of written files that will be verified. 
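For illustration (not part of this patch), the new knob could be used as follows, assuming a Hive catalog named "hive":

    # etc/catalog/hive.properties
    # Validation only applies when the optimized writer is enabled.
    parquet.experimental-optimized-writer.enabled=true
    # Re-read and verify roughly 5% of written Parquet files (0 disables validation).
    parquet.optimized-writer.validation-percentage=5

    -- Per-session override through the equivalent catalog session property:
    SET SESSION hive.parquet_optimized_writer_validation_percentage = 0;
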
--- docs/src/main/sphinx/connector/hive.rst | 6 + .../parquet/ColumnStatisticsValidation.java | 177 +++++ .../parquet/ParquetCorruptionException.java | 10 + .../trino/parquet/ParquetValidationUtils.java | 8 + .../trino/parquet/ParquetWriteValidation.java | 664 ++++++++++++++++++ .../java/io/trino/parquet/ValidationHash.java | 144 ++++ .../trino/parquet/reader/MetadataReader.java | 26 +- .../trino/parquet/reader/ParquetReader.java | 67 +- .../trino/parquet/writer/ParquetWriter.java | 103 ++- .../plugin/deltalake/DeltaLakeMergeSink.java | 12 +- .../plugin/deltalake/DeltaLakePageSink.java | 2 + .../DeltaLakePageSourceProvider.java | 3 +- .../DeltaLakeUpdatablePageSource.java | 3 +- .../checkpoint/CheckpointEntryIterator.java | 3 +- .../plugin/hive/HiveSessionProperties.java | 25 + .../hive/parquet/ParquetFileWriter.java | 34 +- .../parquet/ParquetFileWriterFactory.java | 42 +- .../hive/parquet/ParquetPageSource.java | 2 +- .../parquet/ParquetPageSourceFactory.java | 12 +- .../hive/parquet/ParquetWriterConfig.java | 19 + .../plugin/hive/TestHiveFileFormats.java | 8 +- .../hive/benchmark/StandardFileFormats.java | 3 +- .../plugin/hive/parquet/ParquetTester.java | 23 +- .../hive/parquet/TestParquetWriterConfig.java | 9 +- .../iceberg/IcebergFileWriterFactory.java | 1 + .../iceberg/IcebergPageSourceProvider.java | 2 +- .../iceberg/IcebergParquetFileWriter.java | 3 + 27 files changed, 1377 insertions(+), 34 deletions(-) create mode 100644 lib/trino-parquet/src/main/java/io/trino/parquet/ColumnStatisticsValidation.java create mode 100644 lib/trino-parquet/src/main/java/io/trino/parquet/ParquetWriteValidation.java create mode 100644 lib/trino-parquet/src/main/java/io/trino/parquet/ValidationHash.java diff --git a/docs/src/main/sphinx/connector/hive.rst b/docs/src/main/sphinx/connector/hive.rst index 439d6fd54c09..a1c6a35b2170 100644 --- a/docs/src/main/sphinx/connector/hive.rst +++ b/docs/src/main/sphinx/connector/hive.rst @@ -484,6 +484,12 @@ with Parquet files performed by the Hive connector. definition. The equivalent catalog session property is ``parquet_use_column_names``. - ``true`` + * - ``parquet.optimized-writer.validation-percentage`` + - Percentage of parquet files to validate after write by re-reading the whole file + when ``parquet.experimental-optimized-writer.enabled`` is set to ``true``. + The equivalent catalog session property is ``parquet_optimized_writer_validation_percentage``. + Validation can be turned off by setting this property to ``0``. + - ``5`` Metastore configuration properties diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ColumnStatisticsValidation.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ColumnStatisticsValidation.java new file mode 100644 index 000000000000..71e29b6c1134 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ColumnStatisticsValidation.java @@ -0,0 +1,177 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.ColumnarArray; +import io.trino.spi.block.ColumnarMap; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.block.ColumnarArray.toColumnarArray; +import static io.trino.spi.block.ColumnarMap.toColumnarMap; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +class ColumnStatisticsValidation +{ + private final Type type; + private final List fieldBuilders; + + private long valuesCount; + private long nonLeafValuesCount; + + public ColumnStatisticsValidation(Type type) + { + this.type = requireNonNull(type, "type is null"); + this.fieldBuilders = type.getTypeParameters().stream() + .map(ColumnStatisticsValidation::new) + .collect(toImmutableList()); + } + + public void addBlock(Block block) + { + addBlock(block, new ColumnStatistics(0, 0)); + } + + public List build() + { + if (fieldBuilders.isEmpty()) { + return ImmutableList.of(new ColumnStatistics(valuesCount, nonLeafValuesCount)); + } + return fieldBuilders.stream() + .flatMap(builder -> builder.build().stream()) + .collect(toImmutableList()); + } + + private void addBlock(Block block, ColumnStatistics columnStatistics) + { + if (fieldBuilders.isEmpty()) { + addPrimitiveBlock(block); + valuesCount += columnStatistics.valuesCount(); + nonLeafValuesCount += columnStatistics.nonLeafValuesCount(); + return; + } + + List fields; + ColumnStatistics mergedColumnStatistics; + if (type instanceof ArrayType) { + ColumnarArray columnarArray = toColumnarArray(block); + fields = ImmutableList.of(columnarArray.getElementsBlock()); + mergedColumnStatistics = columnStatistics.merge(addArrayBlock(columnarArray)); + } + else if (type instanceof MapType) { + ColumnarMap columnarMap = toColumnarMap(block); + fields = ImmutableList.of(columnarMap.getKeysBlock(), columnarMap.getValuesBlock()); + mergedColumnStatistics = columnStatistics.merge(addMapBlock(columnarMap)); + } + else if (type instanceof RowType) { + ColumnarRow columnarRow = ColumnarRow.toColumnarRow(block); + ImmutableList.Builder fieldsBuilder = ImmutableList.builder(); + for (int index = 0; index < columnarRow.getFieldCount(); index++) { + fieldsBuilder.add(columnarRow.getField(index)); + } + fields = fieldsBuilder.build(); + mergedColumnStatistics = columnStatistics.merge(addRowBlock(columnarRow)); + } + else { + throw new TrinoException(NOT_SUPPORTED, format("Unsupported type: %s", type)); + } + + for (int i = 0; i < fieldBuilders.size(); i++) { + fieldBuilders.get(i).addBlock(fields.get(i), mergedColumnStatistics); + } + } + + private void addPrimitiveBlock(Block block) + { + valuesCount += block.getPositionCount(); + if (!block.mayHaveNull()) { + return; + } + int nullsCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nullsCount += block.isNull(position) ? 
1 : 0; + } + nonLeafValuesCount += nullsCount; + } + + private static ColumnStatistics addMapBlock(ColumnarMap block) + { + if (!block.mayHaveNull()) { + int emptyEntriesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + emptyEntriesCount += block.getEntryCount(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(emptyEntriesCount, emptyEntriesCount); + } + int nonLeafValuesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nonLeafValuesCount += block.isNull(position) || block.getEntryCount(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(nonLeafValuesCount, nonLeafValuesCount); + } + + private static ColumnStatistics addArrayBlock(ColumnarArray block) + { + if (!block.mayHaveNull()) { + int emptyEntriesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + emptyEntriesCount += block.getLength(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(emptyEntriesCount, emptyEntriesCount); + } + int nonLeafValuesCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nonLeafValuesCount += block.isNull(position) || block.getLength(position) == 0 ? 1 : 0; + } + return new ColumnStatistics(nonLeafValuesCount, nonLeafValuesCount); + } + + private static ColumnStatistics addRowBlock(ColumnarRow block) + { + if (!block.mayHaveNull()) { + return new ColumnStatistics(0, 0); + } + int nullsCount = 0; + for (int position = 0; position < block.getPositionCount(); position++) { + nullsCount += block.isNull(position) ? 1 : 0; + } + return new ColumnStatistics(nullsCount, nullsCount); + } + + /** + * @param valuesCount Count of values for a column field, including nulls, empty and defined values. + * @param nonLeafValuesCount Count of non-leaf values for a column field, this is nulls count for primitives + * and count of values below the max definition level for nested types + */ + record ColumnStatistics(long valuesCount, long nonLeafValuesCount) + { + ColumnStatistics merge(ColumnStatistics other) + { + return new ColumnStatistics( + valuesCount + other.valuesCount(), + nonLeafValuesCount + other.nonLeafValuesCount()); + } + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java index 3f3498177c30..719190d336ea 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetCorruptionException.java @@ -34,4 +34,14 @@ public ParquetCorruptionException(Throwable cause, String messageFormat, Object. { super(format(messageFormat, args), cause); } + + public ParquetCorruptionException(ParquetDataSourceId dataSourceId, String messageFormat, Object... args) + { + super(formatMessage(dataSourceId, messageFormat, args)); + } + + private static String formatMessage(ParquetDataSourceId dataSourceId, String messageFormat, Object[] args) + { + return "Malformed Parquet file. 
" + format(messageFormat, args) + " [" + dataSourceId + "]"; + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java index ee56f75aba43..c194405c0029 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetValidationUtils.java @@ -26,4 +26,12 @@ public static void validateParquet(boolean condition, String formatString, Objec throw new ParquetCorruptionException(format(formatString, args)); } } + + public static void validateParquet(boolean condition, ParquetDataSourceId dataSourceId, String formatString, Object... args) + throws ParquetCorruptionException + { + if (!condition) { + throw new ParquetCorruptionException(dataSourceId, formatString, args); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetWriteValidation.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetWriteValidation.java new file mode 100644 index 000000000000..65eeab7ee484 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetWriteValidation.java @@ -0,0 +1,664 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.parquet; + +import com.google.common.collect.ImmutableList; +import io.airlift.slice.SizeOf; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.airlift.slice.XxHash64; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.type.Type; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.format.ColumnChunk; +import org.apache.parquet.format.ColumnMetaData; +import org.apache.parquet.format.RowGroup; +import org.apache.parquet.format.converter.ParquetMetadataConverter; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.parquet.internal.hadoop.metadata.IndexReference; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.openjdk.jol.info.ClassLayout; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.slice.SizeOf.SIZE_OF_INT; +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.trino.parquet.ColumnStatisticsValidation.ColumnStatistics; +import static io.trino.parquet.ParquetValidationUtils.validateParquet; +import static io.trino.parquet.ParquetWriteValidation.IndexReferenceValidation.fromIndexReference; +import static java.util.Objects.requireNonNull; + +public class ParquetWriteValidation +{ + private static final ParquetMetadataConverter METADATA_CONVERTER = new ParquetMetadataConverter(); + + private final String createdBy; + private final Optional timeZoneId; + private final List columns; + private final List rowGroups; + private final WriteChecksum checksum; + private final List types; + private final List columnNames; + + private ParquetWriteValidation( + String createdBy, + Optional timeZoneId, + List columns, + List rowGroups, + WriteChecksum checksum, + List types, + List columnNames) + { + this.createdBy = requireNonNull(createdBy, "createdBy is null"); + checkArgument(!createdBy.isEmpty(), "createdBy is empty"); + this.timeZoneId = requireNonNull(timeZoneId, "timeZoneId is null"); + this.columns = requireNonNull(columns, "columnPaths is null"); + this.rowGroups = requireNonNull(rowGroups, "rowGroups is null"); + this.checksum = requireNonNull(checksum, "checksum is null"); + this.types = requireNonNull(types, "types is null"); + this.columnNames = requireNonNull(columnNames, "columnNames is null"); + } + + public String getCreatedBy() + { + return createdBy; + } + + public List getTypes() + { + return types; + } + + public List getColumnNames() + { + return columnNames; + } + + public void validateTimeZone(ParquetDataSourceId dataSourceId, Optional actualTimeZoneId) + throws ParquetCorruptionException + { + validateParquet( + timeZoneId.equals(actualTimeZoneId), + dataSourceId, + "Found unexpected time zone %s, expected %s", + actualTimeZoneId, + timeZoneId); + } + + public void validateColumns(ParquetDataSourceId dataSourceId, MessageType schema) + throws 
ParquetCorruptionException
+    {
+        List<ColumnDescriptor> actualColumns = schema.getColumns();
+        validateParquet(
+                actualColumns.size() == columns.size(),
+                dataSourceId,
+                "Found columns %s, expected %s",
+                actualColumns,
+                columns);
+        for (int columnIndex = 0; columnIndex < columns.size(); columnIndex++) {
+            validateColumnDescriptorsSame(actualColumns.get(columnIndex), columns.get(columnIndex), dataSourceId);
+        }
+    }
+
+    public void validateBlocksMetadata(ParquetDataSourceId dataSourceId, List<BlockMetaData> blocksMetaData)
+            throws ParquetCorruptionException
+    {
+        validateParquet(
+                blocksMetaData.size() == rowGroups.size(),
+                dataSourceId,
+                "Number of row groups %d did not match %d",
+                blocksMetaData.size(),
+                rowGroups.size());
+        for (int rowGroupIndex = 0; rowGroupIndex < blocksMetaData.size(); rowGroupIndex++) {
+            BlockMetaData block = blocksMetaData.get(rowGroupIndex);
+            RowGroup rowGroup = rowGroups.get(rowGroupIndex);
+            validateParquet(
+                    block.getRowCount() == rowGroup.getNum_rows(),
+                    dataSourceId,
+                    "Number of rows %d in row group %d did not match %d",
+                    block.getRowCount(),
+                    rowGroupIndex,
+                    rowGroup.getNum_rows());
+
+            List<ColumnChunkMetaData> columnChunkMetaData = block.getColumns();
+            validateParquet(
+                    columnChunkMetaData.size() == rowGroup.getColumnsSize(),
+                    dataSourceId,
+                    "Number of columns %d in row group %d did not match %d",
+                    columnChunkMetaData.size(),
+                    rowGroupIndex,
+                    rowGroup.getColumnsSize());
+
+            for (int columnIndex = 0; columnIndex < columnChunkMetaData.size(); columnIndex++) {
+                ColumnChunkMetaData actualColumnMetadata = columnChunkMetaData.get(columnIndex);
+                ColumnChunk columnChunk = rowGroup.getColumns().get(columnIndex);
+                ColumnMetaData expectedColumnMetadata = columnChunk.getMeta_data();
+                verifyColumnMetadataMatch(
+                        actualColumnMetadata.getCodec().getParquetCompressionCodec().equals(expectedColumnMetadata.getCodec()),
+                        "Compression codec",
+                        actualColumnMetadata.getCodec(),
+                        actualColumnMetadata.getPath(),
+                        rowGroupIndex,
+                        dataSourceId,
+                        expectedColumnMetadata.getCodec());
+
+                verifyColumnMetadataMatch(
+                        actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName().equals(METADATA_CONVERTER.getPrimitive(expectedColumnMetadata.getType())),
+                        "Type",
+                        actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName(),
+                        actualColumnMetadata.getPath(),
+                        rowGroupIndex,
+                        dataSourceId,
+                        expectedColumnMetadata.getType());
+
+                verifyColumnMetadataMatch(
+                        areEncodingsSame(actualColumnMetadata.getEncodings(), expectedColumnMetadata.getEncodings()),
+                        "Encodings",
+                        actualColumnMetadata.getEncodings(),
+                        actualColumnMetadata.getPath(),
+                        rowGroupIndex,
+                        dataSourceId,
+                        expectedColumnMetadata.getEncodings());
+
+                verifyColumnMetadataMatch(
+                        areStatisticsSame(actualColumnMetadata.getStatistics(), expectedColumnMetadata.getStatistics()),
+                        "Statistics",
+                        actualColumnMetadata.getStatistics(),
+                        actualColumnMetadata.getPath(),
+                        rowGroupIndex,
+                        dataSourceId,
+                        expectedColumnMetadata.getStatistics());
+
+                verifyColumnMetadataMatch(
+                        actualColumnMetadata.getFirstDataPageOffset() == expectedColumnMetadata.getData_page_offset(),
+                        "Data page offset",
+                        actualColumnMetadata.getFirstDataPageOffset(),
+                        actualColumnMetadata.getPath(),
+                        rowGroupIndex,
+                        dataSourceId,
+                        expectedColumnMetadata.getData_page_offset());
+
+                verifyColumnMetadataMatch(
+                        actualColumnMetadata.getDictionaryPageOffset() == expectedColumnMetadata.getDictionary_page_offset(),
+                        "Dictionary page offset",
+                        actualColumnMetadata.getDictionaryPageOffset(),
+                        actualColumnMetadata.getPath(),
+                        rowGroupIndex,
+                        dataSourceId,
+
expectedColumnMetadata.getDictionary_page_offset()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getValueCount() == expectedColumnMetadata.getNum_values(), + "Value count", + actualColumnMetadata.getValueCount(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getNum_values()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getTotalUncompressedSize() == expectedColumnMetadata.getTotal_uncompressed_size(), + "Total uncompressed size", + actualColumnMetadata.getTotalUncompressedSize(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getTotal_uncompressed_size()); + + verifyColumnMetadataMatch( + actualColumnMetadata.getTotalSize() == expectedColumnMetadata.getTotal_compressed_size(), + "Total size", + actualColumnMetadata.getTotalSize(), + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnMetadata.getTotal_compressed_size()); + + IndexReferenceValidation expectedColumnIndexReference = new IndexReferenceValidation(columnChunk.getColumn_index_offset(), columnChunk.getColumn_index_length()); + IndexReference actualColumnIndexReference = actualColumnMetadata.getColumnIndexReference(); + verifyColumnMetadataMatch( + actualColumnIndexReference == null || fromIndexReference(actualColumnMetadata.getColumnIndexReference()).equals(expectedColumnIndexReference), + "Column index reference", + actualColumnIndexReference, + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedColumnIndexReference); + + IndexReferenceValidation expectedOffsetIndexReference = new IndexReferenceValidation(columnChunk.getOffset_index_offset(), columnChunk.getOffset_index_length()); + IndexReference actualOffsetIndexReference = actualColumnMetadata.getOffsetIndexReference(); + verifyColumnMetadataMatch( + actualOffsetIndexReference == null || fromIndexReference(actualOffsetIndexReference).equals(expectedOffsetIndexReference), + "Offset index reference", + actualOffsetIndexReference, + actualColumnMetadata.getPath(), + rowGroupIndex, + dataSourceId, + expectedOffsetIndexReference); + } + } + } + + public void validateChecksum(ParquetDataSourceId dataSourceId, WriteChecksum actualChecksum) + throws ParquetCorruptionException + { + validateParquet( + checksum.totalRowCount() == actualChecksum.totalRowCount(), + dataSourceId, + "Write validation failed: Expected row count %d, found %d", + checksum.totalRowCount(), + actualChecksum.totalRowCount()); + + List columnHashes = actualChecksum.columnHashes(); + for (int columnIndex = 0; columnIndex < columnHashes.size(); columnIndex++) { + long expectedHash = checksum.columnHashes().get(columnIndex); + validateParquet( + expectedHash == columnHashes.get(columnIndex), + dataSourceId, + "Invalid checksum for column %s: Expected hash %d, found %d", + columnIndex, + expectedHash, + columnHashes.get(columnIndex)); + } + } + + public record WriteChecksum(long totalRowCount, List columnHashes) + { + public WriteChecksum(long totalRowCount, List columnHashes) + { + this.totalRowCount = totalRowCount; + this.columnHashes = ImmutableList.copyOf(requireNonNull(columnHashes, "columnHashes is null")); + } + } + + public static class WriteChecksumBuilder + { + private final List validationHashes; + private final List columnHashes; + private final byte[] longBuffer = new byte[Long.BYTES]; + private final Slice longSlice = Slices.wrappedBuffer(longBuffer); + + private long totalRowCount; + + private WriteChecksumBuilder(List types) + { + 
this.validationHashes = requireNonNull(types, "types is null").stream() + .map(ValidationHash::createValidationHash) + .collect(toImmutableList()); + + ImmutableList.Builder columnHashes = ImmutableList.builder(); + for (Type ignored : types) { + columnHashes.add(new XxHash64()); + } + this.columnHashes = columnHashes.build(); + } + + public static WriteChecksumBuilder createWriteChecksumBuilder(List readTypes) + { + return new WriteChecksumBuilder(readTypes); + } + + public void addPage(Page page) + { + requireNonNull(page, "page is null"); + checkArgument( + page.getChannelCount() == columnHashes.size(), + "Invalid page: page channels count %s did not match columns count %s", + page.getChannelCount(), + columnHashes.size()); + + for (int channel = 0; channel < columnHashes.size(); channel++) { + ValidationHash validationHash = validationHashes.get(channel); + Block block = page.getBlock(channel); + XxHash64 xxHash64 = columnHashes.get(channel); + for (int position = 0; position < block.getPositionCount(); position++) { + long hash = validationHash.hash(block, position); + longSlice.setLong(0, hash); + xxHash64.update(longBuffer); + } + } + totalRowCount += page.getPositionCount(); + } + + public WriteChecksum build() + { + return new WriteChecksum( + totalRowCount, + columnHashes.stream() + .map(XxHash64::hash) + .collect(toImmutableList())); + } + } + + public void validateRowGroupStatistics(ParquetDataSourceId dataSourceId, BlockMetaData blockMetaData, List actualColumnStatistics) + throws ParquetCorruptionException + { + List columnChunks = blockMetaData.getColumns(); + checkArgument( + columnChunks.size() == actualColumnStatistics.size(), + "Column chunk metadata count %s did not match column fields count %s", + columnChunks.size(), + actualColumnStatistics.size()); + + for (int columnIndex = 0; columnIndex < columnChunks.size(); columnIndex++) { + ColumnChunkMetaData columnMetaData = columnChunks.get(columnIndex); + ColumnStatistics columnStatistics = actualColumnStatistics.get(columnIndex); + long expectedValuesCount = columnMetaData.getValueCount(); + validateParquet( + expectedValuesCount == columnStatistics.valuesCount(), + dataSourceId, + "Invalid values count for column %s: Expected %d, found %d", + columnIndex, + expectedValuesCount, + columnStatistics.valuesCount()); + + Statistics parquetStatistics = columnMetaData.getStatistics(); + if (parquetStatistics.isNumNullsSet()) { + long expectedNullsCount = parquetStatistics.getNumNulls(); + validateParquet( + expectedNullsCount == columnStatistics.nonLeafValuesCount(), + dataSourceId, + "Invalid nulls count for column %s: Expected %d, found %d", + columnIndex, + expectedNullsCount, + columnStatistics.nonLeafValuesCount()); + } + } + } + + public static class StatisticsValidation + { + private final List types; + private List columnStatisticsValidations; + + private StatisticsValidation(List types) + { + this.types = requireNonNull(types, "types is null"); + this.columnStatisticsValidations = types.stream() + .map(ColumnStatisticsValidation::new) + .collect(toImmutableList()); + } + + public static StatisticsValidation createStatisticsValidationBuilder(List readTypes) + { + return new StatisticsValidation(readTypes); + } + + public void addPage(Page page) + { + requireNonNull(page, "page is null"); + checkArgument( + page.getChannelCount() == columnStatisticsValidations.size(), + "Invalid page: page channels count %s did not match columns count %s", + page.getChannelCount(), + columnStatisticsValidations.size()); + + for (int 
channel = 0; channel < columnStatisticsValidations.size(); channel++) { + ColumnStatisticsValidation columnStatisticsValidation = columnStatisticsValidations.get(channel); + columnStatisticsValidation.addBlock(page.getBlock(channel)); + } + } + + public void reset() + { + this.columnStatisticsValidations = types.stream() + .map(ColumnStatisticsValidation::new) + .collect(toImmutableList()); + } + + public List build() + { + return this.columnStatisticsValidations.stream() + .flatMap(validation -> validation.build().stream()) + .collect(toImmutableList()); + } + } + + public static class ParquetWriteValidationBuilder + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ParquetWriteValidationBuilder.class).instanceSize(); + private static final int COLUMN_DESCRIPTOR_INSTANCE_SIZE = ClassLayout.parseClass(ColumnDescriptor.class).instanceSize(); + private static final int PRIMITIVE_TYPE_INSTANCE_SIZE = ClassLayout.parseClass(PrimitiveType.class).instanceSize(); + + private final List types; + private final List columnNames; + private final WriteChecksumBuilder checksum; + + private String createdBy; + private Optional timeZoneId = Optional.empty(); + private List columns; + private List rowGroups; + private long retainedSize = INSTANCE_SIZE; + + public ParquetWriteValidationBuilder(List types, List columnNames) + { + this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); + this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null")); + checkArgument( + types.size() == columnNames.size(), + "Types count %s did not match column names count %s", + types.size(), + columnNames.size()); + this.checksum = new WriteChecksumBuilder(types); + retainedSize += estimatedSizeOf(types, type -> 0) + + estimatedSizeOf(columnNames, SizeOf::estimatedSizeOf); + } + + public long getRetainedSize() + { + return retainedSize; + } + + public void setCreatedBy(String createdBy) + { + this.createdBy = createdBy; + retainedSize += estimatedSizeOf(createdBy); + } + + public void setTimeZone(Optional timeZoneId) + { + this.timeZoneId = timeZoneId; + timeZoneId.ifPresent(id -> retainedSize += estimatedSizeOf(id)); + } + + public void setColumns(List columns) + { + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + retainedSize += estimatedSizeOf(columns, descriptor -> { + return COLUMN_DESCRIPTOR_INSTANCE_SIZE + + (2 * SIZE_OF_INT) // maxRep, maxDef + + estimatedSizeOfStringArray(descriptor.getPath()) + + PRIMITIVE_TYPE_INSTANCE_SIZE + + (3 * SIZE_OF_INT); // primitive, length, columnOrder + }); + } + + public void setRowGroups(List rowGroups) + { + this.rowGroups = ImmutableList.copyOf(requireNonNull(rowGroups, "rowGroups is null")); + } + + public void addPage(Page page) + { + checksum.addPage(page); + } + + public ParquetWriteValidation build() + { + return new ParquetWriteValidation( + createdBy, + timeZoneId, + columns, + rowGroups, + checksum.build(), + types, + columnNames); + } + } + + // parquet-mr IndexReference class lacks equals and toString implementations + static class IndexReferenceValidation + { + private final long offset; + private final int length; + + private IndexReferenceValidation(long offset, int length) + { + this.offset = offset; + this.length = length; + } + + static IndexReferenceValidation fromIndexReference(IndexReference indexReference) + { + return new IndexReferenceValidation(indexReference.getOffset(), indexReference.getLength()); + } + + @Override + public boolean equals(Object o) + { + if 
(this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IndexReferenceValidation that = (IndexReferenceValidation) o; + return offset == that.offset && length == that.length; + } + + @Override + public int hashCode() + { + return Objects.hash(offset, length); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("offset", offset) + .add("length", length) + .toString(); + } + } + + private static void verifyColumnMetadataMatch( + boolean condition, + String name, + T actual, + ColumnPath path, + int rowGroup, + ParquetDataSourceId dataSourceId, + U expected) + throws ParquetCorruptionException + { + if (!condition) { + throw new ParquetCorruptionException( + dataSourceId, + "%s [%s] for column %s in row group %d did not match [%s]", + name, + actual, + path, + rowGroup, + expected); + } + } + + private static boolean areEncodingsSame(Set actual, List expected) + { + return actual.equals(expected.stream().map(METADATA_CONVERTER::getEncoding).collect(toImmutableSet())); + } + + private static boolean areStatisticsSame(org.apache.parquet.column.statistics.Statistics actual, org.apache.parquet.format.Statistics expected) + { + Statistics.Builder expectedStatsBuilder = Statistics.getBuilderForReading(actual.type()); + if (expected.isSetNull_count()) { + expectedStatsBuilder.withNumNulls(expected.getNull_count()); + } + if (expected.isSetMin_value()) { + expectedStatsBuilder.withMin(expected.getMin_value()); + } + if (expected.isSetMax_value()) { + expectedStatsBuilder.withMax(expected.getMax_value()); + } + return actual.equals(expectedStatsBuilder.build()); + } + + private static void validateColumnDescriptorsSame(ColumnDescriptor actual, ColumnDescriptor expected, ParquetDataSourceId dataSourceId) + throws ParquetCorruptionException + { + // Column names are lower-cased by MetadataReader#readFooter + validateParquet( + Arrays.equals(actual.getPath(), Arrays.stream(expected.getPath()).map(field -> field.toLowerCase(Locale.ENGLISH)).toArray()), + dataSourceId, + "Column path %s did not match expected column path %s", + actual.getPath(), + expected.getPath()); + + validateParquet( + actual.getMaxDefinitionLevel() == expected.getMaxDefinitionLevel(), + dataSourceId, + "Column %s max definition level %d did not match expected max definition level %d", + actual.getPath(), + actual.getMaxDefinitionLevel(), + expected.getMaxDefinitionLevel()); + + validateParquet( + actual.getMaxRepetitionLevel() == expected.getMaxRepetitionLevel(), + dataSourceId, + "Column %s max repetition level %d did not match expected max repetition level %d", + actual.getPath(), + actual.getMaxRepetitionLevel(), + expected.getMaxRepetitionLevel()); + + PrimitiveType actualPrimitiveType = actual.getPrimitiveType(); + PrimitiveType expectedPrimitiveType = expected.getPrimitiveType(); + // We don't use PrimitiveType#equals directly because column names are lower-cased by MetadataReader#readFooter + validateParquet( + actualPrimitiveType.getPrimitiveTypeName().equals(expectedPrimitiveType.getPrimitiveTypeName()) + && actualPrimitiveType.getTypeLength() == expectedPrimitiveType.getTypeLength() + && actualPrimitiveType.getRepetition().equals(expectedPrimitiveType.getRepetition()) + && actualPrimitiveType.getName().equals(expectedPrimitiveType.getName().toLowerCase(Locale.ENGLISH)) + && Objects.equals(actualPrimitiveType.getLogicalTypeAnnotation(), expectedPrimitiveType.getLogicalTypeAnnotation()), + dataSourceId, + "Column %s primitive type %s did not 
match expected primitive type %s", + actual.getPath(), + actualPrimitiveType, + expectedPrimitiveType); + } + + private static long estimatedSizeOfStringArray(String[] path) + { + long size = sizeOf(path); + for (String field : path) { + size += estimatedSizeOf(field); + } + return size; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ValidationHash.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ValidationHash.java new file mode 100644 index 000000000000..86cb0ceceafa --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ValidationHash.java @@ -0,0 +1,144 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet; + +import io.trino.spi.block.Block; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodType; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.type.StandardTypes.ARRAY; +import static io.trino.spi.type.StandardTypes.MAP; +import static io.trino.spi.type.StandardTypes.ROW; +import static java.lang.invoke.MethodHandles.lookup; +import static java.util.Objects.requireNonNull; + +/** + * Based on io.trino.rcfile.ValidationHash and io.trino.orc.ValidationHash + * with minor differences in handling of timestamp and map types. 
+ */ +class ValidationHash +{ + // This value is a large arbitrary prime + private static final long NULL_HASH_CODE = 0x6e3efbd56c16a0cbL; + + private static final MethodHandle MAP_HASH; + private static final MethodHandle ARRAY_HASH; + private static final MethodHandle ROW_HASH; + + static { + try { + MAP_HASH = lookup().findStatic( + ValidationHash.class, + "mapHash", + MethodType.methodType(long.class, Type.class, ValidationHash.class, ValidationHash.class, Block.class, int.class)); + ARRAY_HASH = lookup().findStatic( + ValidationHash.class, + "arrayHash", + MethodType.methodType(long.class, Type.class, ValidationHash.class, Block.class, int.class)); + ROW_HASH = lookup().findStatic( + ValidationHash.class, + "rowHash", + MethodType.methodType(long.class, Type.class, ValidationHash[].class, Block.class, int.class)); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + // This should really come from the environment, but there is not good way to get a value here + private static final TypeOperators VALIDATION_TYPE_OPERATORS_CACHE = new TypeOperators(); + + public static ValidationHash createValidationHash(Type type) + { + requireNonNull(type, "type is null"); + if (type.getTypeSignature().getBase().equals(MAP)) { + ValidationHash keyHash = createValidationHash(type.getTypeParameters().get(0)); + ValidationHash valueHash = createValidationHash(type.getTypeParameters().get(1)); + return new ValidationHash(MAP_HASH.bindTo(type).bindTo(keyHash).bindTo(valueHash)); + } + + if (type.getTypeSignature().getBase().equals(ARRAY)) { + ValidationHash elementHash = createValidationHash(type.getTypeParameters().get(0)); + return new ValidationHash(ARRAY_HASH.bindTo(type).bindTo(elementHash)); + } + + if (type.getTypeSignature().getBase().equals(ROW)) { + ValidationHash[] fieldHashes = type.getTypeParameters().stream() + .map(ValidationHash::createValidationHash) + .toArray(ValidationHash[]::new); + return new ValidationHash(ROW_HASH.bindTo(type).bindTo(fieldHashes)); + } + + return new ValidationHash(VALIDATION_TYPE_OPERATORS_CACHE.getHashCodeOperator(type, InvocationConvention.simpleConvention(FAIL_ON_NULL, BLOCK_POSITION))); + } + + private final MethodHandle hashCodeOperator; + + private ValidationHash(MethodHandle hashCodeOperator) + { + this.hashCodeOperator = requireNonNull(hashCodeOperator, "hashCodeOperator is null"); + } + + public long hash(Block block, int position) + { + if (block.isNull(position)) { + return NULL_HASH_CODE; + } + try { + return (long) hashCodeOperator.invokeExact(block, position); + } + catch (Throwable throwable) { + throwIfUnchecked(throwable); + throw new RuntimeException(throwable); + } + } + + private static long mapHash(Type type, ValidationHash keyHash, ValidationHash valueHash, Block block, int position) + { + Block mapBlock = (Block) type.getObject(block, position); + long hash = 0; + for (int i = 0; i < mapBlock.getPositionCount(); i += 2) { + hash = 31 * hash + keyHash.hash(mapBlock, i); + hash = 31 * hash + valueHash.hash(mapBlock, i + 1); + } + return hash; + } + + private static long arrayHash(Type type, ValidationHash elementHash, Block block, int position) + { + Block array = (Block) type.getObject(block, position); + long hash = 0; + for (int i = 0; i < array.getPositionCount(); i++) { + hash = 31 * hash + elementHash.hash(array, i); + } + return hash; + } + + private static long rowHash(Type type, ValidationHash[] fieldHashes, Block block, int position) + { + Block row = (Block) type.getObject(block, position); + long hash = 0; + for 
(int i = 0; i < row.getPositionCount(); i++) { + hash = 31 * hash + fieldHashes[i].hash(row, i); + } + return hash; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java index ad267e7c8375..4dc824c16c12 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/MetadataReader.java @@ -16,7 +16,10 @@ import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetWriteValidation; import org.apache.parquet.CorruptStatistics; import org.apache.parquet.column.statistics.BinaryStatistics; import org.apache.parquet.format.ColumnChunk; @@ -61,6 +64,7 @@ import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.Math.toIntExact; +import static org.apache.hadoop.hive.ql.io.parquet.write.DataWritableWriteSupport.WRITER_TIMEZONE; import static org.apache.parquet.format.Util.readFileMetaData; import static org.apache.parquet.format.converter.ParquetMetadataConverterUtil.getLogicalTypeAnnotation; @@ -75,7 +79,7 @@ public final class MetadataReader private MetadataReader() {} - public static ParquetMetadata readFooter(ParquetDataSource dataSource) + public static ParquetMetadata readFooter(ParquetDataSource dataSource, Optional parquetWriteValidation) throws IOException { // Parquet File Layout: @@ -165,7 +169,12 @@ public static ParquetMetadata readFooter(ParquetDataSource dataSource) keyValueMetaData.put(keyValue.key, keyValue.value); } } - return new ParquetMetadata(new org.apache.parquet.hadoop.metadata.FileMetaData(messageType, keyValueMetaData, fileMetaData.getCreated_by()), blocks); + org.apache.parquet.hadoop.metadata.FileMetaData parquetFileMetadata = new org.apache.parquet.hadoop.metadata.FileMetaData( + messageType, + keyValueMetaData, + fileMetaData.getCreated_by()); + validateFileMetadata(dataSource.getId(), parquetFileMetadata, parquetWriteValidation); + return new ParquetMetadata(parquetFileMetadata, blocks); } private static MessageType readParquetSchema(List schema) @@ -380,4 +389,17 @@ private static IndexReference toOffsetIndexReference(ColumnChunk columnChunk) } return null; } + + private static void validateFileMetadata(ParquetDataSourceId dataSourceId, org.apache.parquet.hadoop.metadata.FileMetaData fileMetaData, Optional parquetWriteValidation) + throws ParquetCorruptionException + { + if (parquetWriteValidation.isEmpty()) { + return; + } + ParquetWriteValidation writeValidation = parquetWriteValidation.get(); + writeValidation.validateTimeZone( + dataSourceId, + Optional.ofNullable(fileMetaData.getKeyValueMetaData().get(WRITER_TIMEZONE))); + writeValidation.validateColumns(dataSourceId, fileMetaData.getSchema()); + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index b325fe508447..d653937a4247 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -29,6 +29,7 @@ import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetReaderOptions; 
+import io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.PrimitiveField; import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.FilteredOffsetIndex.OffsetRange; @@ -74,6 +75,10 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.parquet.ParquetValidationUtils.validateParquet; +import static io.trino.parquet.ParquetWriteValidation.StatisticsValidation; +import static io.trino.parquet.ParquetWriteValidation.StatisticsValidation.createStatisticsValidationBuilder; +import static io.trino.parquet.ParquetWriteValidation.WriteChecksumBuilder; +import static io.trino.parquet.ParquetWriteValidation.WriteChecksumBuilder.createWriteChecksumBuilder; import static io.trino.parquet.reader.ListColumnReader.calculateCollectionOffsets; import static java.lang.Math.max; import static java.lang.Math.min; @@ -122,6 +127,9 @@ public class ParquetReader private AggregatedMemoryContext currentRowGroupMemoryContext; private final Multimap chunkReaders; private final List> columnIndexStore; + private final Optional writeValidation; + private final Optional writeChecksumBuilder; + private final Optional rowGroupStatisticsValidation; private final List blockRowRanges; private final Map paths = new HashMap<>(); private final ParquetBlockFactory blockFactory; @@ -139,7 +147,7 @@ public ParquetReader( Function exceptionTransform) throws IOException { - this(fileCreatedBy, columnFields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, exceptionTransform, Optional.empty(), nCopies(blocks.size(), Optional.empty())); + this(fileCreatedBy, columnFields, blocks, firstRowsOfBlocks, dataSource, timeZone, memoryContext, options, exceptionTransform, Optional.empty(), nCopies(blocks.size(), Optional.empty()), Optional.empty()); } public ParquetReader( @@ -153,7 +161,8 @@ public ParquetReader( ParquetReaderOptions options, Function exceptionTransform, Optional parquetPredicate, - List> columnIndexStore) + List> columnIndexStore, + Optional writeValidation) throws IOException { this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null"); @@ -172,6 +181,16 @@ public ParquetReader( checkArgument(blocks.size() == firstRowsOfBlocks.size(), "elements of firstRowsOfBlocks must correspond to blocks"); + this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); + validateWrite( + validation -> fileCreatedBy.equals(Optional.of(validation.getCreatedBy())), + "Expected created by %s, found %s", + writeValidation.map(ParquetWriteValidation::getCreatedBy), + fileCreatedBy); + validateBlockMetadata(blocks); + this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(validation.getTypes())); + this.rowGroupStatisticsValidation = writeValidation.map(validation -> createStatisticsValidationBuilder(validation.getTypes())); + this.blockRowRanges = listWithNulls(this.blocks.size()); for (PrimitiveField field : primitiveFields) { ColumnDescriptor columnDescriptor = field.getDescriptor(); @@ -231,9 +250,15 @@ public void close() freeCurrentRowGroupBuffers(); currentRowGroupMemoryContext.close(); dataSource.close(); + + if (writeChecksumBuilder.isPresent()) { + ParquetWriteValidation parquetWriteValidation = writeValidation.orElseThrow(); + parquetWriteValidation.validateChecksum(dataSource.getId(), writeChecksumBuilder.get().build()); + } } public Page nextPage() + throws IOException { int batchSize = nextBatch(); if (batchSize <= 0) { @@ -246,7 +271,9 @@ public Page nextPage() Field 
field = columnFields.get(channel); blocks[channel] = blockFactory.createBlock(batchSize, () -> readBlock(field)); } - return new Page(batchSize, blocks); + Page page = new Page(batchSize, blocks); + validateWritePageChecksum(page); + return page; } /** @@ -258,6 +285,7 @@ public long lastBatchStartRow() } private int nextBatch() + throws IOException { if (nextRowInGroup >= currentGroupRowCount && !advanceToNextRowGroup()) { return -1; @@ -273,10 +301,18 @@ private int nextBatch() } private boolean advanceToNextRowGroup() + throws IOException { currentRowGroupMemoryContext.close(); currentRowGroupMemoryContext = memoryContext.newAggregatedMemoryContext(); freeCurrentRowGroupBuffers(); + + if (currentRowGroup >= 0 && rowGroupStatisticsValidation.isPresent()) { + StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.get(); + writeValidation.orElseThrow().validateRowGroupStatistics(dataSource.getId(), currentBlockMetadata, statisticsValidation.build()); + statisticsValidation.reset(); + } + currentRowGroup++; if (currentRowGroup == blocks.size()) { return false; @@ -541,4 +577,29 @@ private RowRanges getRowRanges(FilterPredicate filter, int blockIndex) } return rowRanges; } + + private void validateWritePageChecksum(Page page) + { + if (writeChecksumBuilder.isPresent()) { + page = page.getLoadedPage(); + writeChecksumBuilder.get().addPage(page); + rowGroupStatisticsValidation.orElseThrow().addPage(page); + } + } + + private void validateBlockMetadata(List blockMetaData) + throws ParquetCorruptionException + { + if (writeValidation.isPresent()) { + writeValidation.get().validateBlocksMetadata(dataSource.getId(), blockMetaData); + } + } + + private void validateWrite(java.util.function.Predicate test, String messageFormat, Object... args) + throws ParquetCorruptionException + { + if (writeValidation.isPresent() && !test.test(writeValidation.get())) { + throw new ParquetCorruptionException(dataSource.getId(), "Write validation failed: " + messageFormat, args); + } + } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java index 950f1fa4c8a8..934191dab0a2 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java @@ -20,6 +20,13 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.units.DataSize; +import io.trino.parquet.Field; +import io.trino.parquet.ParquetCorruptionException; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.ParquetWriteValidation; +import io.trino.parquet.reader.MetadataReader; +import io.trino.parquet.reader.ParquetReader; import io.trino.parquet.writer.ColumnWriter.BufferData; import io.trino.spi.Page; import io.trino.spi.type.Type; @@ -29,7 +36,10 @@ import org.apache.parquet.format.KeyValue; import org.apache.parquet.format.RowGroup; import org.apache.parquet.format.Util; +import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.schema.MessageType; import org.joda.time.DateTimeZone; import org.openjdk.jol.info.ClassLayout; @@ -40,19 +50,27 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import 
java.util.function.Consumer; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.Slices.wrappedBuffer; import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.trino.parquet.ParquetTypeUtils.constructField; +import static io.trino.parquet.ParquetTypeUtils.getColumnIO; +import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; +import static io.trino.parquet.ParquetWriteValidation.ParquetWriteValidationBuilder; import static io.trino.parquet.writer.ParquetDataOutput.createDataOutput; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; import static org.apache.hadoop.hive.ql.io.parquet.write.DataWritableWriteSupport.WRITER_TIMEZONE; import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0; @@ -74,6 +92,7 @@ public class ParquetWriter private final Optional parquetTimeZone; private final ImmutableList.Builder rowGroupBuilder = ImmutableList.builder(); + private final Optional validationBuilder; private List columnWriters; private int rows; @@ -90,17 +109,23 @@ public ParquetWriter( ParquetWriterOptions writerOption, CompressionCodecName compressionCodecName, String trinoVersion, - Optional parquetTimeZone) + Optional parquetTimeZone, + Optional validationBuilder) { + this.validationBuilder = requireNonNull(validationBuilder, "validationBuilder is null"); this.outputStream = new OutputStreamSliceOutput(requireNonNull(outputStream, "outputstream is null")); this.messageType = requireNonNull(messageType, "messageType is null"); this.primitiveTypes = requireNonNull(primitiveTypes, "primitiveTypes is null"); this.writerOption = requireNonNull(writerOption, "writerOption is null"); this.compressionCodecName = requireNonNull(compressionCodecName, "compressionCodecName is null"); this.parquetTimeZone = requireNonNull(parquetTimeZone, "parquetTimeZone is null"); + this.createdBy = formatCreatedBy(requireNonNull(trinoVersion, "trinoVersion is null")); + + recordValidation(validation -> validation.setTimeZone(parquetTimeZone.map(DateTimeZone::getID))); + recordValidation(validation -> validation.setColumns(messageType.getColumns())); + recordValidation(validation -> validation.setCreatedBy(createdBy)); initColumnWriters(); this.chunkMaxLogicalBytes = max(1, CHUNK_MAX_BYTES / 2); - this.createdBy = formatCreatedBy(requireNonNull(trinoVersion, "trinoVersion is null")); } public long getWrittenBytes() @@ -117,7 +142,8 @@ public long getRetainedBytes() { return INSTANCE_SIZE + outputStream.getRetainedSize() + - columnWriters.stream().mapToLong(ColumnWriter::getRetainedBytes).sum(); + columnWriters.stream().mapToLong(ColumnWriter::getRetainedBytes).sum() + + validationBuilder.map(ParquetWriteValidationBuilder::getRetainedSize).orElse(0L); } public void write(Page page) @@ -131,6 +157,9 @@ public void write(Page page) checkArgument(page.getChannelCount() == columnWriters.size()); + Page validationPage = page; + 
recordValidation(validation -> validation.addPage(validationPage)); + while (page != null) { int chunkRows = min(page.getPositionCount(), writerOption.getBatchSize()); Page chunk = page.getRegion(0, chunkRows); @@ -190,6 +219,70 @@ public void close() bufferedBytes = 0; } + public void validate(ParquetDataSource input) + throws ParquetCorruptionException + { + checkState(validationBuilder.isPresent(), "validation is not enabled"); + ParquetWriteValidation writeValidation = validationBuilder.get().build(); + try { + ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation)); + try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) { + for (Page page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) { + // fully load the page + page.getLoadedPage(); + } + } + } + catch (IOException e) { + if (e instanceof ParquetCorruptionException) { + throw (ParquetCorruptionException) e; + } + throw new ParquetCorruptionException(input.getId(), "Validation failed with exception %s", e); + } + } + + private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetadata parquetMetadata, ParquetWriteValidation writeValidation) + throws IOException + { + org.apache.parquet.hadoop.metadata.FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); + MessageColumnIO messageColumnIO = getColumnIO(fileMetaData.getSchema(), fileMetaData.getSchema()); + ImmutableList.Builder columnFields = ImmutableList.builder(); + for (int i = 0; i < writeValidation.getTypes().size(); i++) { + columnFields.add(constructField( + writeValidation.getTypes().get(i), + lookupColumnByName(messageColumnIO, writeValidation.getColumnNames().get(i))) + .orElseThrow()); + } + long nextStart = 0; + ImmutableList.Builder blockStartsBuilder = ImmutableList.builder(); + for (BlockMetaData block : parquetMetadata.getBlocks()) { + blockStartsBuilder.add(nextStart); + nextStart += block.getRowCount(); + } + List blockStarts = blockStartsBuilder.build(); + return new ParquetReader( + Optional.ofNullable(fileMetaData.getCreatedBy()), + columnFields.build(), + parquetMetadata.getBlocks(), + blockStarts, + input, + parquetTimeZone.orElseThrow(), + newSimpleAggregatedMemoryContext(), + new ParquetReaderOptions(), + exception -> { + throwIfUnchecked(exception); + return new RuntimeException(exception); + }, + Optional.empty(), + nCopies(blockStarts.size(), Optional.empty()), + Optional.of(writeValidation)); + } + + private void recordValidation(Consumer task) + { + validationBuilder.ifPresent(task); + } + // Parquet File Layout: // // MAGIC @@ -239,7 +332,9 @@ private void writeFooter() throws IOException { checkState(closed); - Slice footer = getFooter(rowGroupBuilder.build(), messageType); + List rowGroups = rowGroupBuilder.build(); + Slice footer = getFooter(rowGroups, messageType); + recordValidation(validation -> validation.setRowGroups(rowGroups)); createDataOutput(footer).writeData(outputStream); Slice footerSize = Slices.allocate(SIZE_OF_INT); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index f9a0ce2a5f2d..cc2ce8f27f68 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -225,11 +225,12 @@ private FileWriter 
createParquetFileWriter(FileSystem fileSystem, Path path, Lis }) .collect(toImmutableList()); + List dataColumnNames = dataColumns.stream() + .map(DeltaLakeColumnHandle::getName) + .collect(toImmutableList()); ParquetSchemaConverter schemaConverter = new ParquetSchemaConverter( parquetTypes, - dataColumns.stream() - .map(DeltaLakeColumnHandle::getName) - .collect(toImmutableList()), + dataColumnNames, false, false); @@ -237,12 +238,14 @@ private FileWriter createParquetFileWriter(FileSystem fileSystem, Path path, Lis fileSystem.create(path), rollbackAction, parquetTypes, + dataColumnNames, schemaConverter.getMessageType(), schemaConverter.getPrimitiveTypes(), parquetWriterOptions, IntStream.range(0, dataColumns.size()).toArray(), compressionCodecName, trinoVersion, + Optional.empty(), Optional.empty()); } catch (IOException e) { @@ -310,7 +313,8 @@ private ReaderPageSource createParquetPageSource(Path path) true, parquetDateTimeZone, new FileFormatDataSourceStats(), - new ParquetReaderOptions()); + new ParquetReaderOptions(), + Optional.empty()); } private static class FileDeletion diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java index 85dc41403550..b4d0cf0218c1 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSink.java @@ -496,12 +496,14 @@ private FileWriter createParquetFileWriter(Path path) fileSystem.create(path), rollbackAction, parquetTypes, + dataColumnNames, schemaConverter.getMessageType(), schemaConverter.getPrimitiveTypes(), parquetWriterOptions, identityMapping, compressionCodecName, trinoVersion, + Optional.empty(), Optional.empty()); } catch (IOException e) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index b3d7b321eba7..50039d5859dc 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -180,7 +180,8 @@ public ConnectorPageSource createPageSource( parquetDateTimeZone, fileFormatDataSourceStats, parquetReaderOptions.withMaxReadBlockSize(getParquetMaxReadBlockSize(session)) - .withUseColumnIndex(isParquetUseColumnIndex(session))); + .withUseColumnIndex(isParquetUseColumnIndex(session)), + Optional.empty()); verify(pageSource.getReaderColumns().isEmpty(), "All columns expected to be base columns"); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java index 46b33bb5d785..6aee909915af 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java @@ -578,7 +578,8 @@ private ReaderPageSource createParquetPageSource(TupleDomain p parquetDateTimeZone, new FileFormatDataSourceStats(), parquetReaderOptions.withMaxReadBlockSize(getParquetMaxReadBlockSize(this.session)) - .withUseColumnIndex(isParquetUseColumnIndex(this.session))); + .withUseColumnIndex(isParquetUseColumnIndex(this.session)), 
+ Optional.empty()); } private DeltaLakeWriter createWriter(Path targetFile, List allColumns, List dataColumns) diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index 41f7c8f856c5..8a52ffc5efa2 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -181,7 +181,8 @@ public CheckpointEntryIterator( true, DateTimeZone.UTC, stats, - parquetReaderOptions); + parquetReaderOptions, + Optional.empty()); verify(pageSource.getReaderColumns().isEmpty(), "All columns expected to be base columns"); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java index 22052f828b87..98019b128764 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java @@ -89,6 +89,7 @@ public final class HiveSessionProperties private static final String PARQUET_WRITER_BLOCK_SIZE = "parquet_writer_block_size"; private static final String PARQUET_WRITER_PAGE_SIZE = "parquet_writer_page_size"; private static final String PARQUET_WRITER_BATCH_SIZE = "parquet_writer_batch_size"; + private static final String PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE = "parquet_optimized_writer_validation_percentage"; private static final String MAX_SPLIT_SIZE = "max_split_size"; private static final String MAX_INITIAL_SPLIT_SIZE = "max_initial_split_size"; private static final String RCFILE_OPTIMIZED_WRITER_VALIDATE = "rcfile_optimized_writer_validate"; @@ -338,6 +339,23 @@ public HiveSessionProperties( "Parquet: Maximum number of rows passed to the writer in each batch", parquetWriterConfig.getBatchSize(), false), + new PropertyMetadata<>( + PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE, + "Parquet: sample percentage for validation of written files", + DOUBLE, + Double.class, + parquetWriterConfig.getValidationPercentage(), + false, + value -> { + double doubleValue = (double) value; + if (doubleValue < 0.0 || doubleValue > 100.0) { + throw new TrinoException( + INVALID_SESSION_PROPERTY, + format("%s must be between 0.0 and 100.0 inclusive: %s", PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE, doubleValue)); + } + return doubleValue; + }, + value -> value), dataSizeProperty( MAX_SPLIT_SIZE, "Max split size", @@ -685,6 +703,13 @@ public static int getParquetBatchSize(ConnectorSession session) return session.getProperty(PARQUET_WRITER_BATCH_SIZE, Integer.class); } + public static boolean isParquetOptimizedWriterValidate(ConnectorSession session) + { + double percentage = session.getProperty(PARQUET_OPTIMIZED_WRITER_VALIDATION_PERCENTAGE, Double.class); + checkArgument(percentage >= 0.0 && percentage <= 100.0); + return ThreadLocalRandom.current().nextDouble(100) < percentage; + } + public static DataSize getMaxSplitSize(ConnectorSession session) { return session.getProperty(MAX_SPLIT_SIZE, DataSize.class); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java index 
fe31b0ec367a..67ab1d3f4d36 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriter.java @@ -14,6 +14,7 @@ package io.trino.plugin.hive.parquet; import com.google.common.collect.ImmutableList; +import io.trino.parquet.ParquetDataSource; import io.trino.parquet.writer.ParquetWriter; import io.trino.parquet.writer.ParquetWriterOptions; import io.trino.plugin.hive.FileWriter; @@ -31,40 +32,51 @@ import java.io.IOException; import java.io.OutputStream; import java.io.UncheckedIOException; +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.Callable; +import java.util.function.Supplier; import static com.google.common.base.MoreObjects.toStringHelper; +import static io.trino.parquet.ParquetWriteValidation.ParquetWriteValidationBuilder; import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR; import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_DATA_ERROR; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITE_VALIDATION_FAILED; import static java.util.Objects.requireNonNull; public class ParquetFileWriter implements FileWriter { private static final int INSTANCE_SIZE = ClassLayout.parseClass(ParquetFileWriter.class).instanceSize(); + private static final ThreadMXBean THREAD_MX_BEAN = ManagementFactory.getThreadMXBean(); private final ParquetWriter parquetWriter; private final Callable rollbackAction; private final int[] fileInputColumnIndexes; private final List nullBlocks; + private final Optional> validationInputFactory; + private long validationCpuNanos; public ParquetFileWriter( OutputStream outputStream, Callable rollbackAction, List fileColumnTypes, + List fileColumnNames, MessageType messageType, Map, Type> primitiveTypes, ParquetWriterOptions parquetWriterOptions, int[] fileInputColumnIndexes, CompressionCodecName compressionCodecName, String trinoVersion, - Optional parquetTimeZone) + Optional parquetTimeZone, + Optional> validationInputFactory) { requireNonNull(outputStream, "outputStream is null"); requireNonNull(trinoVersion, "trinoVersion is null"); + this.validationInputFactory = requireNonNull(validationInputFactory, "validationInputFactory is null"); this.parquetWriter = new ParquetWriter( outputStream, @@ -73,7 +85,10 @@ public ParquetFileWriter( parquetWriterOptions, compressionCodecName, trinoVersion, - parquetTimeZone); + parquetTimeZone, + validationInputFactory.isPresent() + ? 
Optional.of(new ParquetWriteValidationBuilder(fileColumnTypes, fileColumnNames)) + : Optional.empty()); this.rollbackAction = requireNonNull(rollbackAction, "rollbackAction is null"); this.fileInputColumnIndexes = requireNonNull(fileInputColumnIndexes, "fileInputColumnIndexes is null"); @@ -136,6 +151,19 @@ public void commit() } throw new TrinoException(HIVE_WRITER_CLOSE_ERROR, "Error committing write parquet to Hive", e); } + + if (validationInputFactory.isPresent()) { + try { + try (ParquetDataSource input = validationInputFactory.get().get()) { + long startThreadCpuTime = THREAD_MX_BEAN.getCurrentThreadCpuTime(); + parquetWriter.validate(input); + validationCpuNanos += THREAD_MX_BEAN.getCurrentThreadCpuTime() - startThreadCpuTime; + } + } + catch (IOException | UncheckedIOException e) { + throw new TrinoException(HIVE_WRITE_VALIDATION_FAILED, e); + } + } } @Override @@ -157,7 +185,7 @@ public void rollback() @Override public long getValidationCpuNanos() { - return 0; + return validationCpuNanos; } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java index c8719ab71eb9..5f6d54eaf282 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetFileWriterFactory.java @@ -14,8 +14,12 @@ package io.trino.plugin.hive.parquet; import io.trino.hdfs.HdfsEnvironment; +import io.trino.parquet.ParquetDataSource; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.writer.ParquetSchemaConverter; import io.trino.parquet.writer.ParquetWriterOptions; +import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.FileWriter; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HiveFileWriterFactory; @@ -35,6 +39,8 @@ import org.apache.parquet.hadoop.ParquetOutputFormat; import org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.joda.time.DateTimeZone; +import org.weakref.jmx.Flatten; +import org.weakref.jmx.Managed; import javax.inject.Inject; @@ -44,11 +50,14 @@ import java.util.OptionalInt; import java.util.Properties; import java.util.concurrent.Callable; +import java.util.function.Supplier; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_INT96_TIMESTAMP_ENCODING; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_LEGACY_DECIMAL_ENCODING; import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_OPEN_ERROR; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITE_VALIDATION_FAILED; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; +import static io.trino.plugin.hive.HiveSessionProperties.isParquetOptimizedWriterValidate; import static io.trino.plugin.hive.util.HiveUtil.getColumnNames; import static io.trino.plugin.hive.util.HiveUtil.getColumnTypes; import static java.util.Objects.requireNonNull; @@ -61,18 +70,21 @@ public class ParquetFileWriterFactory private final NodeVersion nodeVersion; private final TypeManager typeManager; private final DateTimeZone parquetTimeZone; + private final FileFormatDataSourceStats readStats; @Inject public ParquetFileWriterFactory( HdfsEnvironment hdfsEnvironment, NodeVersion nodeVersion, TypeManager typeManager, - HiveConfig hiveConfig) + HiveConfig hiveConfig, + FileFormatDataSourceStats 
readStats) { this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.parquetTimeZone = hiveConfig.getParquetDateTimeZone(); + this.readStats = requireNonNull(readStats, "readStats is null"); } @Override @@ -127,23 +139,49 @@ public Optional createFileWriter( HIVE_PARQUET_USE_LEGACY_DECIMAL_ENCODING, HIVE_PARQUET_USE_INT96_TIMESTAMP_ENCODING); + Optional> validationInputFactory = Optional.empty(); + if (isParquetOptimizedWriterValidate(session)) { + validationInputFactory = Optional.of(() -> { + try { + return new HdfsParquetDataSource( + new ParquetDataSourceId(path.toString()), + fileSystem.getFileStatus(path).getLen(), + fileSystem.open(path), + readStats, + new ParquetReaderOptions()); + } + catch (IOException e) { + throw new TrinoException(HIVE_WRITE_VALIDATION_FAILED, e); + } + }); + } + return Optional.of(new ParquetFileWriter( fileSystem.create(path, false), rollbackAction, fileColumnTypes, + fileColumnNames, schemaConverter.getMessageType(), schemaConverter.getPrimitiveTypes(), parquetWriterOptions, fileInputColumnIndexes, compressionCodecName, nodeVersion.toString(), - Optional.of(parquetTimeZone))); + Optional.of(parquetTimeZone), + validationInputFactory)); } catch (IOException e) { throw new TrinoException(HIVE_WRITER_OPEN_ERROR, "Error creating Parquet file", e); } } + @Managed + @Flatten + public FileFormatDataSourceStats getReadStats() + { + return readStats; + } + private static CompressionCodecName getCompression(JobConf configuration) { String compressionName = configuration.get(ParquetOutputFormat.COMPRESSION); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java index f035b82ec8d7..3f1aed4ddd45 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java @@ -95,7 +95,7 @@ public Page getNextPage() try { page = getColumnAdaptationsPage(parquetReader.nextPage()); } - catch (RuntimeException e) { + catch (IOException | RuntimeException e) { closeAllSuppress(e, this); throw handleException(parquetReader.getDataSource().getId(), e); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index fcdf69b71b45..37eb876d144f 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -23,6 +23,7 @@ import io.trino.parquet.ParquetDataSource; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.ParquetWriteValidation; import io.trino.parquet.predicate.Predicate; import io.trino.parquet.reader.MetadataReader; import io.trino.parquet.reader.ParquetReader; @@ -177,7 +178,8 @@ public Optional createPageSource( stats, options.withIgnoreStatistics(isParquetIgnoreStatistics(session)) .withMaxReadBlockSize(getParquetMaxReadBlockSize(session)) - .withUseColumnIndex(isParquetUseColumnIndex(session)))); + .withUseColumnIndex(isParquetUseColumnIndex(session)), + Optional.empty())); } /** @@ -192,7 +194,8 @@ public static 
ReaderPageSource createPageSource( boolean useColumnNames, DateTimeZone timeZone, FileFormatDataSourceStats stats, - ParquetReaderOptions options) + ParquetReaderOptions options, + Optional parquetWriteValidation) { // Ignore predicates on partial columns for now. effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn()); @@ -204,7 +207,7 @@ public static ReaderPageSource createPageSource( try { dataSource = new TrinoParquetDataSource(inputFile, options, stats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, parquetWriteValidation); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); fileSchema = fileMetaData.getSchema(); @@ -267,7 +270,8 @@ && predicateMatches(parquetPredicate, block, dataSource, descriptorsByPath, parq options, exception -> handleException(dataSourceId, exception), Optional.of(parquetPredicate), - columnIndexes.build()), + columnIndexes.build(), + parquetWriteValidation), parquetReaderColumns); return new ReaderPageSource(parquetPageSource, readerProjections); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java index 2415e8fa41f6..1b2f7653aa4e 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetWriterConfig.java @@ -20,6 +20,9 @@ import io.trino.parquet.writer.ParquetWriterOptions; import org.apache.parquet.hadoop.ParquetWriter; +import javax.validation.constraints.DecimalMax; +import javax.validation.constraints.DecimalMin; + public class ParquetWriterConfig { private boolean parquetOptimizedWriterEnabled; @@ -27,6 +30,7 @@ public class ParquetWriterConfig private DataSize blockSize = DataSize.ofBytes(ParquetWriter.DEFAULT_BLOCK_SIZE); private DataSize pageSize = DataSize.ofBytes(ParquetWriter.DEFAULT_PAGE_SIZE); private int batchSize = ParquetWriterOptions.DEFAULT_BATCH_SIZE; + private double validationPercentage = 5; public DataSize getBlockSize() { @@ -80,4 +84,19 @@ public int getBatchSize() { return batchSize; } + + @DecimalMin("0.0") + @DecimalMax("100.0") + public double getValidationPercentage() + { + return validationPercentage; + } + + @Config("parquet.optimized-writer.validation-percentage") + @ConfigDescription("Percentage of parquet files to validate after write by re-reading the whole file") + public ParquetWriterConfig setValidationPercentage(double validationPercentage) + { + this.validationPercentage = validationPercentage; + return this; + } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java index af479988f7be..8fdaf598cf1b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java @@ -443,7 +443,11 @@ public void testParquetPageSourceGzip(int rowCount, long fileSizePadding) public void testOptimizedParquetWriter(int rowCount) throws Exception { - ConnectorSession session = getHiveSession(new HiveConfig(), new ParquetWriterConfig().setParquetOptimizedWriterEnabled(true)); + ConnectorSession session = getHiveSession( + new HiveConfig(), + new ParquetWriterConfig() + .setParquetOptimizedWriterEnabled(true) + 
.setValidationPercentage(100.0)); assertTrue(HiveSessionProperties.isParquetOptimizedWriterEnabled(session)); List testColumns = getTestColumnsSupportedByParquet(); @@ -451,7 +455,7 @@ public void testOptimizedParquetWriter(int rowCount) .withSession(session) .withColumns(testColumns) .withRowsCount(rowCount) - .withFileWriterFactory(new ParquetFileWriterFactory(HDFS_ENVIRONMENT, new NodeVersion("test-version"), TESTING_TYPE_MANAGER, new HiveConfig())) + .withFileWriterFactory(new ParquetFileWriterFactory(HDFS_ENVIRONMENT, new NodeVersion("test-version"), TESTING_TYPE_MANAGER, new HiveConfig(), STATS)) .isReadableByPageSource(PARQUET_PAGE_SOURCE_FACTORY); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java index 6463abcbceca..a0ccde2547b2 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/benchmark/StandardFileFormats.java @@ -319,7 +319,8 @@ public PrestoParquetFormatWriter(File targetFile, List columnNames, List ParquetWriterOptions.builder().build(), compressionCodec.getParquetCompressionCodec(), "test-version", - Optional.of(DateTimeZone.getDefault())); + Optional.of(DateTimeZone.getDefault()), + Optional.empty()); } @Override diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java index f2967164076c..566fc1d08f6a 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java @@ -22,9 +22,12 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.units.DataSize; +import io.trino.parquet.ParquetDataSourceId; +import io.trino.parquet.ParquetReaderOptions; import io.trino.parquet.writer.ParquetSchemaConverter; import io.trino.parquet.writer.ParquetWriter; import io.trino.parquet.writer.ParquetWriterOptions; +import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HiveSessionProperties; import io.trino.plugin.hive.HiveStorageFormat; @@ -38,6 +41,7 @@ import io.trino.plugin.hive.parquet.write.TestMapredParquetOutputFormat; import io.trino.spi.Page; import io.trino.spi.PageBuilder; +import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorPageSource; @@ -59,6 +63,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.testing.TestingConnectorSession; +import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter; import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe; @@ -101,9 +106,11 @@ import static com.google.common.collect.Iterables.transform; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.hadoop.ConfigurationInstantiator.newEmptyConfiguration; +import static io.trino.parquet.ParquetWriteValidation.ParquetWriteValidationBuilder; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_INT96_TIMESTAMP_ENCODING; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_LEGACY_DECIMAL_ENCODING; import static 
io.trino.plugin.hive.AbstractTestHiveFileFormats.getFieldFromCursor; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITE_VALIDATION_FAILED; import static io.trino.plugin.hive.HiveSessionProperties.getParquetMaxReadBlockSize; import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT; import static io.trino.plugin.hive.HiveTestUtils.getHiveSession; @@ -751,7 +758,8 @@ private static void writeParquetColumnTrino( .build(), compressionCodecName, "test-version", - Optional.of(DateTimeZone.getDefault())); + Optional.of(DateTimeZone.getDefault()), + Optional.of(new ParquetWriteValidationBuilder(types, columnNames))); PageBuilder pageBuilder = new PageBuilder(types); for (int i = 0; i < types.size(); ++i) { @@ -768,6 +776,19 @@ private static void writeParquetColumnTrino( pageBuilder.declarePositions(size); writer.write(pageBuilder.build()); writer.close(); + Path path = new Path(outputFile.getPath()); + FileSystem fileSystem = HDFS_ENVIRONMENT.getFileSystem(SESSION.getIdentity(), path, newEmptyConfiguration()); + try { + writer.validate(new HdfsParquetDataSource( + new ParquetDataSourceId(path.toString()), + fileSystem.getFileStatus(path).getLen(), + fileSystem.open(path), + new FileFormatDataSourceStats(), + new ParquetReaderOptions())); + } + catch (IOException e) { + throw new TrinoException(HIVE_WRITE_VALIDATION_FAILED, e); + } } private static void writeValue(Type type, BlockBuilder blockBuilder, Object value) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java index ce1f83ab3ea4..c72fca4a075e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetWriterConfig.java @@ -35,7 +35,8 @@ public void testDefaults() .setParquetOptimizedWriterEnabled(false) .setBlockSize(DataSize.ofBytes(ParquetWriter.DEFAULT_BLOCK_SIZE)) .setPageSize(DataSize.ofBytes(ParquetWriter.DEFAULT_PAGE_SIZE)) - .setBatchSize(ParquetWriterOptions.DEFAULT_BATCH_SIZE)); + .setBatchSize(ParquetWriterOptions.DEFAULT_BATCH_SIZE) + .setValidationPercentage(5)); } @Test @@ -60,13 +61,15 @@ public void testExplicitPropertyMappings() "parquet.experimental-optimized-writer.enabled", "true", "parquet.writer.block-size", "234MB", "parquet.writer.page-size", "11MB", - "parquet.writer.batch-size", "100"); + "parquet.writer.batch-size", "100", + "parquet.optimized-writer.validation-percentage", "10"); ParquetWriterConfig expected = new ParquetWriterConfig() .setParquetOptimizedWriterEnabled(true) .setBlockSize(DataSize.of(234, MEGABYTE)) .setPageSize(DataSize.of(11, MEGABYTE)) - .setBatchSize(100); + .setBatchSize(100) + .setValidationPercentage(10); assertFullMapping(properties, expected); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java index bc0562c4c6ba..52c370983911 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergFileWriterFactory.java @@ -188,6 +188,7 @@ private IcebergFileWriter createParquetWriter( outputStream, rollbackAction, fileColumnTypes, + fileColumnNames, convert(icebergSchema, "table"), makeTypeMap(fileColumnTypes, fileColumnNames), parquetWriterOptions, diff --git 
a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index 7fb9dd54e506..4ab1227a7cf3 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -930,7 +930,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( ParquetDataSource dataSource = null; try { dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats); - ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource); + ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty()); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); if (nameMapping.isPresent() && !ParquetSchemaUtil.hasIds(fileSchema)) { diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java index 446fe586a2be..c4d7ec0e34d2 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetFileWriter.java @@ -45,6 +45,7 @@ public IcebergParquetFileWriter( OutputStream outputStream, Callable rollbackAction, List fileColumnTypes, + List fileColumnNames, MessageType messageType, Map, Type> primitiveTypes, ParquetWriterOptions parquetWriterOptions, @@ -57,12 +58,14 @@ public IcebergParquetFileWriter( super(outputStream, rollbackAction, fileColumnTypes, + fileColumnNames, messageType, primitiveTypes, parquetWriterOptions, fileInputColumnIndexes, compressionCodecName, trinoVersion, + Optional.empty(), Optional.empty()); this.metricsConfig = requireNonNull(metricsConfig, "metricsConfig is null"); this.outputPath = requireNonNull(outputPath, "outputPath is null");
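
For context on the new parquet.optimized-writer.validation-percentage property introduced above: the patch wires per-file validation through ParquetFileWriter.commit() and gates it with isParquetOptimizedWriterValidate(session), but the sampling logic that turns the percentage into a per-file yes/no decision is not part of the hunks shown here. The following is a minimal, self-contained sketch of how such sampling could work; the ValidationSampling class, its shouldValidate(double) helper, and the use of ThreadLocalRandom are illustrative assumptions, not code from this patch.

    import java.util.concurrent.ThreadLocalRandom;

    public final class ValidationSampling
    {
        private ValidationSampling() {}

        // Decide whether a newly written Parquet file should be re-read and
        // validated, given a percentage between 0.0 (never) and 100.0 (always).
        public static boolean shouldValidate(double validationPercentage)
        {
            // nextDouble(100.0) returns a value in [0.0, 100.0), so a
            // percentage of 100.0 always validates and 0.0 never does.
            return ThreadLocalRandom.current().nextDouble(100.0) < validationPercentage;
        }

        public static void main(String[] args)
        {
            // With validation-percentage = 100.0, as set in the
            // TestHiveFileFormats change above, every file is validated;
            // with the default of 5, roughly one file in twenty is.
            System.out.println(shouldValidate(100.0));
            System.out.println(shouldValidate(5.0));
        }
    }

Under this assumed scheme, the default of 5 keeps the cost of re-reading written files low in production, while tests opt into 100.0 so that every file produced by the optimized writer is checked against the ParquetWriteValidationBuilder metadata captured at write time.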