diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java index d31bee6f2e6c..dfa8e00bf9d0 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/PrimitiveColumnWriter.java @@ -20,6 +20,7 @@ import io.trino.parquet.writer.repdef.RepLevelIterable; import io.trino.parquet.writer.repdef.RepLevelIterables; import io.trino.parquet.writer.valuewriter.PrimitiveValueWriter; +import io.trino.spi.block.Block; import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; @@ -49,7 +50,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.parquet.writer.ParquetCompressor.getCompressor; import static io.trino.parquet.writer.ParquetDataOutput.createDataOutput; -import static io.trino.parquet.writer.repdef.RepLevelIterables.getIterator; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; import static org.apache.parquet.bytes.BytesInput.copy; @@ -113,36 +113,52 @@ public void writeBlock(ColumnChunk columnChunk) throws IOException { checkState(!closed); - - ColumnChunk current = new ColumnChunk(columnChunk.getBlock(), - ImmutableList.builder() - .addAll(columnChunk.getDefLevelIterables()) - .add(DefLevelIterables.of(columnChunk.getBlock(), maxDefinitionLevel)) - .build(), - ImmutableList.builder() - .addAll(columnChunk.getRepLevelIterables()) - .add(RepLevelIterables.of(columnChunk.getBlock())) - .build()); - // write values primitiveValueWriter.write(columnChunk.getBlock()); - // write definition levels - Iterator defIterator = DefLevelIterables.getIterator(current.getDefLevelIterables()); - while (defIterator.hasNext()) { - int next = defIterator.next(); - definitionLevelWriter.writeInteger(next); - if (next != maxDefinitionLevel) { - currentPageNullCounts++; + if (columnChunk.getDefLevelIterables().isEmpty()) { + // write definition levels for flat data types + Block block = columnChunk.getBlock(); + if (!block.mayHaveNull()) { + for (int position = 0; position < block.getPositionCount(); position++) { + definitionLevelWriter.writeInteger(maxDefinitionLevel); + } + } + else { + for (int position = 0; position < block.getPositionCount(); position++) { + byte isNull = (byte) (block.isNull(position) ? 1 : 0); + definitionLevelWriter.writeInteger(maxDefinitionLevel - isNull); + currentPageNullCounts += isNull; + } + } + valueCount += block.getPositionCount(); + } + else { + // write definition levels for nested data types + Iterator defIterator = DefLevelIterables.getIterator(ImmutableList.builder() + .addAll(columnChunk.getDefLevelIterables()) + .add(DefLevelIterables.of(columnChunk.getBlock(), maxDefinitionLevel)) + .build()); + while (defIterator.hasNext()) { + int next = defIterator.next(); + definitionLevelWriter.writeInteger(next); + if (next != maxDefinitionLevel) { + currentPageNullCounts++; + } + valueCount++; } - valueCount++; } - // write repetition levels - Iterator repIterator = getIterator(current.getRepLevelIterables()); - while (repIterator.hasNext()) { - int next = repIterator.next(); - repetitionLevelWriter.writeInteger(next); + if (columnDescriptor.getMaxRepetitionLevel() > 0) { + // write repetition levels for nested types + Iterator repIterator = RepLevelIterables.getIterator(ImmutableList.builder() + .addAll(columnChunk.getRepLevelIterables()) + .add(RepLevelIterables.of(columnChunk.getBlock())) + .build()); + while (repIterator.hasNext()) { + int next = repIterator.next(); + repetitionLevelWriter.writeInteger(next); + } } updateBufferedBytes();