diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java index 0697feec1ee3..5327c3f2706e 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java @@ -32,6 +32,7 @@ import io.trino.parquet.writer.valuewriter.TimestampNanosValueWriter; import io.trino.parquet.writer.valuewriter.TimestampTzMicrosValueWriter; import io.trino.parquet.writer.valuewriter.TimestampTzMillisValueWriter; +import io.trino.parquet.writer.valuewriter.TrinoValuesWriterFactory; import io.trino.parquet.writer.valuewriter.UuidValueWriter; import io.trino.spi.TrinoException; import io.trino.spi.type.CharType; @@ -91,7 +92,8 @@ static List getColumnWriters( CompressionCodec compressionCodec, Optional parquetTimeZone) { - WriteBuilder writeBuilder = new WriteBuilder(messageType, trinoTypes, parquetProperties, compressionCodec, parquetTimeZone); + TrinoValuesWriterFactory valuesWriterFactory = new TrinoValuesWriterFactory(parquetProperties); + WriteBuilder writeBuilder = new WriteBuilder(messageType, trinoTypes, parquetProperties, valuesWriterFactory, compressionCodec, parquetTimeZone); ParquetTypeVisitor.visit(messageType, writeBuilder); return writeBuilder.build(); } @@ -102,6 +104,7 @@ private static class WriteBuilder private final MessageType type; private final Map, Type> trinoTypes; private final ParquetProperties parquetProperties; + private final TrinoValuesWriterFactory valuesWriterFactory; private final CompressionCodec compressionCodec; private final Optional parquetTimeZone; private final ImmutableList.Builder builder = ImmutableList.builder(); @@ -110,12 +113,14 @@ private static class WriteBuilder MessageType messageType, Map, Type> trinoTypes, ParquetProperties parquetProperties, + TrinoValuesWriterFactory valuesWriterFactory, CompressionCodec compressionCodec, Optional parquetTimeZone) { this.type = requireNonNull(messageType, "messageType is null"); this.trinoTypes = requireNonNull(trinoTypes, "trinoTypes is null"); this.parquetProperties = requireNonNull(parquetProperties, "parquetProperties is null"); + this.valuesWriterFactory = requireNonNull(valuesWriterFactory, "valuesWriterFactory is null"); this.compressionCodec = requireNonNull(compressionCodec, "compressionCodec is null"); this.parquetTimeZone = requireNonNull(parquetTimeZone, "parquetTimeZone is null"); } @@ -168,7 +173,7 @@ public ColumnWriter primitive(PrimitiveType primitive) Type trinoType = requireNonNull(trinoTypes.get(ImmutableList.copyOf(path)), "Trino type is null"); return new PrimitiveColumnWriter( columnDescriptor, - getValueWriter(parquetProperties.newValuesWriter(columnDescriptor), trinoType, columnDescriptor.getPrimitiveType(), parquetTimeZone), + getValueWriter(valuesWriterFactory.newValuesWriter(columnDescriptor), trinoType, columnDescriptor.getPrimitiveType(), parquetTimeZone), parquetProperties.newDefinitionLevelWriter(columnDescriptor), parquetProperties.newRepetitionLevelWriter(columnDescriptor), compressionCodec, diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/valuewriter/DictionaryFallbackValuesWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/valuewriter/DictionaryFallbackValuesWriter.java new file mode 100644 index 000000000000..305a591f4fea --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/valuewriter/DictionaryFallbackValuesWriter.java @@ -0,0 +1,234 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.writer.valuewriter; + +import com.google.common.annotations.VisibleForTesting; +import jakarta.annotation.Nullable; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.values.ValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter; +import org.apache.parquet.io.api.Binary; + +import static com.google.common.base.Verify.verify; +import static java.util.Objects.requireNonNull; + +/** + * Based on org.apache.parquet.column.values.fallback.FallbackValuesWriter + */ +public class DictionaryFallbackValuesWriter + extends ValuesWriter +{ + private final ValuesWriter fallBackWriter; + + private boolean fellBackAlready; + private ValuesWriter currentWriter; + @Nullable + private DictionaryValuesWriter initialWriter; + private boolean initialUsedAndHadDictionary; + /* size of raw data, even if dictionary is used, it will not have effect on raw data size, it is used to decide + * if fall back to plain encoding is better by comparing rawDataByteSize with Encoded data size + * It's also used in getBufferedSize, so the page will be written based on raw data size + */ + private long rawDataByteSize; + // indicates if this is the first page being processed + private boolean firstPage = true; + + public DictionaryFallbackValuesWriter(DictionaryValuesWriter initialWriter, ValuesWriter fallBackWriter) + { + super(); + this.initialWriter = initialWriter; + this.fallBackWriter = fallBackWriter; + this.currentWriter = initialWriter; + } + + @Override + public long getBufferedSize() + { + // use raw data size to decide if we want to flush the page + // so the actual size of the page written could be much more smaller + // due to dictionary encoding. This prevents page being too big when fallback happens. + return rawDataByteSize; + } + + @Override + public BytesInput getBytes() + { + if (!fellBackAlready && firstPage) { + // we use the first page to decide if we're going to use this encoding + BytesInput bytes = initialWriter.getBytes(); + if (!initialWriter.isCompressionSatisfying(rawDataByteSize, bytes.size())) { + fallBack(); + // Since fallback happened on first page itself, we can drop the contents of initialWriter + initialWriter.close(); + initialWriter = null; + verify(!initialUsedAndHadDictionary, "initialUsedAndHadDictionary should be false when falling back to PLAIN in first page"); + } + else { + return bytes; + } + } + return currentWriter.getBytes(); + } + + @Override + public Encoding getEncoding() + { + Encoding encoding = currentWriter.getEncoding(); + if (!fellBackAlready && !initialUsedAndHadDictionary) { + initialUsedAndHadDictionary = encoding.usesDictionary(); + } + return encoding; + } + + @Override + public void reset() + { + rawDataByteSize = 0; + firstPage = false; + currentWriter.reset(); + } + + @Override + public void close() + { + if (initialWriter != null) { + initialWriter.close(); + } + fallBackWriter.close(); + } + + @Override + public DictionaryPage toDictPageAndClose() + { + if (initialUsedAndHadDictionary) { + return initialWriter.toDictPageAndClose(); + } + else { + return currentWriter.toDictPageAndClose(); + } + } + + @Override + public void resetDictionary() + { + if (initialUsedAndHadDictionary) { + initialWriter.resetDictionary(); + } + else { + currentWriter.resetDictionary(); + } + currentWriter = initialWriter; + fellBackAlready = false; + initialUsedAndHadDictionary = false; + firstPage = true; + } + + @Override + public long getAllocatedSize() + { + return fallBackWriter.getAllocatedSize() + (initialWriter != null ? initialWriter.getAllocatedSize() : 0); + } + + @Override + public String memUsageString(String prefix) + { + return String.format( + "%s FallbackValuesWriter{\n" + + "%s\n" + + "%s\n" + + "%s}\n", + prefix, + initialWriter != null ? initialWriter.memUsageString(prefix + " initial:") : "", + fallBackWriter.memUsageString(prefix + " fallback:"), + prefix); + } + + // passthrough writing the value + @Override + public void writeByte(int value) + { + rawDataByteSize += Byte.BYTES; + currentWriter.writeByte(value); + checkFallback(); + } + + @Override + public void writeBytes(Binary value) + { + // For raw data, length(4 bytes int) is stored, followed by the binary content itself + rawDataByteSize += value.length() + Integer.BYTES; + currentWriter.writeBytes(value); + checkFallback(); + } + + @Override + public void writeInteger(int value) + { + rawDataByteSize += Integer.BYTES; + currentWriter.writeInteger(value); + checkFallback(); + } + + @Override + public void writeLong(long value) + { + rawDataByteSize += Long.BYTES; + currentWriter.writeLong(value); + checkFallback(); + } + + @Override + public void writeFloat(float value) + { + rawDataByteSize += Float.BYTES; + currentWriter.writeFloat(value); + checkFallback(); + } + + @Override + public void writeDouble(double value) + { + rawDataByteSize += Double.BYTES; + currentWriter.writeDouble(value); + checkFallback(); + } + + @VisibleForTesting + public DictionaryValuesWriter getInitialWriter() + { + return requireNonNull(initialWriter, "initialWriter is null"); + } + + @VisibleForTesting + public ValuesWriter getFallBackWriter() + { + return fallBackWriter; + } + + private void checkFallback() + { + if (!fellBackAlready && initialWriter.shouldFallBack()) { + fallBack(); + } + } + + private void fallBack() + { + fellBackAlready = true; + initialWriter.fallBackAllValuesTo(fallBackWriter); + currentWriter = fallBackWriter; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/valuewriter/TrinoValuesWriterFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/valuewriter/TrinoValuesWriterFactory.java new file mode 100644 index 000000000000..362ee18c3a8f --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/valuewriter/TrinoValuesWriterFactory.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.writer.valuewriter; + +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.column.values.ValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter; +import org.apache.parquet.column.values.plain.BooleanPlainValuesWriter; +import org.apache.parquet.column.values.plain.FixedLenByteArrayPlainValuesWriter; +import org.apache.parquet.column.values.plain.PlainValuesWriter; + +import static org.apache.parquet.column.Encoding.PLAIN_DICTIONARY; + +/** + * Based on org.apache.parquet.column.values.factory.DefaultV1ValuesWriterFactory + */ +public class TrinoValuesWriterFactory +{ + private final ParquetProperties parquetProperties; + + public TrinoValuesWriterFactory(ParquetProperties properties) + { + this.parquetProperties = properties; + } + + public ValuesWriter newValuesWriter(ColumnDescriptor descriptor) + { + return switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { + case BOOLEAN -> new BooleanPlainValuesWriter(); // no dictionary encoding for boolean + case FIXED_LEN_BYTE_ARRAY -> getFixedLenByteArrayValuesWriter(descriptor); + case BINARY -> getBinaryValuesWriter(descriptor); + case INT32 -> getInt32ValuesWriter(descriptor); + case INT64 -> getInt64ValuesWriter(descriptor); + case INT96 -> getInt96ValuesWriter(descriptor); + case DOUBLE -> getDoubleValuesWriter(descriptor); + case FLOAT -> getFloatValuesWriter(descriptor); + }; + } + + private ValuesWriter getFixedLenByteArrayValuesWriter(ColumnDescriptor path) + { + // dictionary encoding was not enabled in PARQUET 1.0 + return new FixedLenByteArrayPlainValuesWriter(path.getTypeLength(), parquetProperties.getInitialSlabSize(), parquetProperties.getPageSizeThreshold(), parquetProperties.getAllocator()); + } + + private ValuesWriter getBinaryValuesWriter(ColumnDescriptor path) + { + ValuesWriter fallbackWriter = new PlainValuesWriter(parquetProperties.getInitialSlabSize(), parquetProperties.getPageSizeThreshold(), parquetProperties.getAllocator()); + return dictWriterWithFallBack(path, parquetProperties, getEncodingForDictionaryPage(), getEncodingForDataPage(), fallbackWriter); + } + + private ValuesWriter getInt32ValuesWriter(ColumnDescriptor path) + { + ValuesWriter fallbackWriter = new PlainValuesWriter(parquetProperties.getInitialSlabSize(), parquetProperties.getPageSizeThreshold(), parquetProperties.getAllocator()); + return dictWriterWithFallBack(path, parquetProperties, getEncodingForDictionaryPage(), getEncodingForDataPage(), fallbackWriter); + } + + private ValuesWriter getInt64ValuesWriter(ColumnDescriptor path) + { + ValuesWriter fallbackWriter = new PlainValuesWriter(parquetProperties.getInitialSlabSize(), parquetProperties.getPageSizeThreshold(), parquetProperties.getAllocator()); + return dictWriterWithFallBack(path, parquetProperties, getEncodingForDictionaryPage(), getEncodingForDataPage(), fallbackWriter); + } + + private ValuesWriter getInt96ValuesWriter(ColumnDescriptor path) + { + ValuesWriter fallbackWriter = new FixedLenByteArrayPlainValuesWriter(12, parquetProperties.getInitialSlabSize(), parquetProperties.getPageSizeThreshold(), parquetProperties.getAllocator()); + return dictWriterWithFallBack(path, parquetProperties, getEncodingForDictionaryPage(), getEncodingForDataPage(), fallbackWriter); + } + + private ValuesWriter getDoubleValuesWriter(ColumnDescriptor path) + { + ValuesWriter fallbackWriter = new PlainValuesWriter(parquetProperties.getInitialSlabSize(), parquetProperties.getPageSizeThreshold(), parquetProperties.getAllocator()); + return dictWriterWithFallBack(path, parquetProperties, getEncodingForDictionaryPage(), getEncodingForDataPage(), fallbackWriter); + } + + private ValuesWriter getFloatValuesWriter(ColumnDescriptor path) + { + ValuesWriter fallbackWriter = new PlainValuesWriter(parquetProperties.getInitialSlabSize(), parquetProperties.getPageSizeThreshold(), parquetProperties.getAllocator()); + return dictWriterWithFallBack(path, parquetProperties, getEncodingForDictionaryPage(), getEncodingForDataPage(), fallbackWriter); + } + + @SuppressWarnings("deprecation") + private static Encoding getEncodingForDataPage() + { + return PLAIN_DICTIONARY; + } + + @SuppressWarnings("deprecation") + private static Encoding getEncodingForDictionaryPage() + { + return PLAIN_DICTIONARY; + } + + private static DictionaryValuesWriter dictionaryWriter(ColumnDescriptor path, ParquetProperties properties, Encoding dictPageEncoding, Encoding dataPageEncoding) + { + return switch (path.getPrimitiveType().getPrimitiveTypeName()) { + case BOOLEAN -> throw new IllegalArgumentException("no dictionary encoding for BOOLEAN"); + case BINARY -> + new DictionaryValuesWriter.PlainBinaryDictionaryValuesWriter(properties.getDictionaryPageSizeThreshold(), dataPageEncoding, dictPageEncoding, properties.getAllocator()); + case INT32 -> + new DictionaryValuesWriter.PlainIntegerDictionaryValuesWriter(properties.getDictionaryPageSizeThreshold(), dataPageEncoding, dictPageEncoding, properties.getAllocator()); + case INT64 -> + new DictionaryValuesWriter.PlainLongDictionaryValuesWriter(properties.getDictionaryPageSizeThreshold(), dataPageEncoding, dictPageEncoding, properties.getAllocator()); + case INT96 -> + new DictionaryValuesWriter.PlainFixedLenArrayDictionaryValuesWriter(properties.getDictionaryPageSizeThreshold(), 12, dataPageEncoding, dictPageEncoding, properties.getAllocator()); + case DOUBLE -> + new DictionaryValuesWriter.PlainDoubleDictionaryValuesWriter(properties.getDictionaryPageSizeThreshold(), dataPageEncoding, dictPageEncoding, properties.getAllocator()); + case FLOAT -> + new DictionaryValuesWriter.PlainFloatDictionaryValuesWriter(properties.getDictionaryPageSizeThreshold(), dataPageEncoding, dictPageEncoding, properties.getAllocator()); + case FIXED_LEN_BYTE_ARRAY -> + new DictionaryValuesWriter.PlainFixedLenArrayDictionaryValuesWriter(properties.getDictionaryPageSizeThreshold(), path.getTypeLength(), dataPageEncoding, dictPageEncoding, properties.getAllocator()); + }; + } + + private static ValuesWriter dictWriterWithFallBack(ColumnDescriptor path, ParquetProperties parquetProperties, Encoding dictPageEncoding, Encoding dataPageEncoding, ValuesWriter writerToFallBackTo) + { + return new DictionaryFallbackValuesWriter(dictionaryWriter(path, parquetProperties, dictPageEncoding, dataPageEncoding), writerToFallBackTo); + } +} diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java index 8c2ebe947375..1f073f7e5bfb 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java @@ -38,6 +38,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.io.UncheckedIOException; import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -48,6 +49,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.parquet.ParquetTypeUtils.constructField; import static io.trino.parquet.ParquetTypeUtils.getColumnIO; +import static io.trino.parquet.ParquetTypeUtils.getParquetEncoding; import static io.trino.parquet.ParquetTypeUtils.lookupColumnByName; import static io.trino.spi.block.ArrayBlock.fromElementBlock; import static io.trino.spi.block.MapBlock.fromKeyValueBlock; @@ -229,4 +231,17 @@ private static Block generateBlock(Type type, int positions) } return blockBuilder.build(); } + + public static DictionaryPage toTrinoDictionaryPage(org.apache.parquet.column.page.DictionaryPage dictionary) + { + try { + return new DictionaryPage( + Slices.wrappedBuffer(dictionary.getBytes().toByteArray()), + dictionary.getDictionarySize(), + getParquetEncoding(dictionary.getEncoding())); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java index 1322f241082e..36f502a9f1b4 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderRowRangesTest.java @@ -58,9 +58,11 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.parquet.ParquetTestUtils.toTrinoDictionaryPage; import static io.trino.parquet.ParquetTypeUtils.getParquetEncoding; import static io.trino.parquet.reader.FilteredRowRanges.RowRange; -import static io.trino.parquet.reader.TestingColumnReader.toTrinoDictionaryPage; +import static io.trino.parquet.reader.TestingRowRanges.toRowRange; +import static io.trino.parquet.reader.TestingRowRanges.toRowRanges; import static io.trino.testing.DataProviders.cartesianProduct; import static io.trino.testing.DataProviders.concat; import static io.trino.testing.DataProviders.toDataProvider; @@ -68,8 +70,6 @@ import static org.apache.parquet.bytes.BytesUtils.getWidthFromMaxInt; import static org.apache.parquet.column.Encoding.RLE_DICTIONARY; import static org.apache.parquet.format.CompressionCodec.UNCOMPRESSED; -import static org.apache.parquet.internal.filter2.columnindex.TestingRowRanges.toRowRange; -import static org.apache.parquet.internal.filter2.columnindex.TestingRowRanges.toRowRanges; import static org.assertj.core.api.Assertions.assertThat; public abstract class AbstractColumnReaderRowRangesTest diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java index 49dee968e07b..0359a113640e 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/AbstractColumnReaderTest.java @@ -58,9 +58,9 @@ import static io.trino.parquet.reader.AbstractColumnReader.shouldProduceDictionaryForType; import static io.trino.parquet.reader.TestingColumnReader.DataPageVersion.V1; import static io.trino.parquet.reader.TestingColumnReader.getDictionaryPage; +import static io.trino.parquet.reader.TestingRowRanges.toRowRange; import static org.apache.parquet.bytes.BytesUtils.getWidthFromMaxInt; import static org.apache.parquet.format.CompressionCodec.UNCOMPRESSED; -import static org.apache.parquet.internal.filter2.columnindex.TestingRowRanges.toRowRange; import static org.assertj.core.api.Assertions.assertThat; import static org.joda.time.DateTimeZone.UTC; diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java index cec26e429690..1f564a66f68c 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java @@ -60,8 +60,6 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.testng.annotations.DataProvider; -import java.io.IOException; -import java.io.UncheckedIOException; import java.math.BigInteger; import java.time.LocalDateTime; import java.util.Arrays; @@ -71,7 +69,7 @@ import java.util.stream.Stream; import static com.google.common.collect.MoreCollectors.onlyElement; -import static io.trino.parquet.ParquetTypeUtils.getParquetEncoding; +import static io.trino.parquet.ParquetTestUtils.toTrinoDictionaryPage; import static io.trino.parquet.ParquetTypeUtils.paddingBigInteger; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; @@ -595,19 +593,6 @@ public static DictionaryPage getDictionaryPage(DictionaryValuesWriter dictionary return toTrinoDictionaryPage(apacheDictionaryPage); } - public static DictionaryPage toTrinoDictionaryPage(org.apache.parquet.column.page.DictionaryPage dictionary) - { - try { - return new DictionaryPage( - Slices.wrappedBuffer(dictionary.getBytes().toByteArray()), - dictionary.getDictionarySize(), - getParquetEncoding(dictionary.getEncoding())); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - @DataProvider(name = "readersWithPageVersions") public static Object[][] readersWithPageVersions() { diff --git a/lib/trino-parquet/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestingRowRanges.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingRowRanges.java similarity index 95% rename from lib/trino-parquet/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestingRowRanges.java rename to lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingRowRanges.java index 856962a9a205..68d0d20bf55b 100644 --- a/lib/trino-parquet/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestingRowRanges.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingRowRanges.java @@ -11,9 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.parquet.internal.filter2.columnindex; +package io.trino.parquet.reader; import org.apache.parquet.internal.column.columnindex.OffsetIndex; +import org.apache.parquet.internal.filter2.columnindex.RowRanges; import java.util.stream.IntStream; diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/decoders/AbstractValueDecodersTest.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/decoders/AbstractValueDecodersTest.java index 240350d5451a..8c37cefc09f5 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/decoders/AbstractValueDecodersTest.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/decoders/AbstractValueDecodersTest.java @@ -17,10 +17,10 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.parquet.ParquetEncoding; +import io.trino.parquet.ParquetTestUtils; import io.trino.parquet.PrimitiveField; import io.trino.parquet.dictionary.Dictionary; import io.trino.parquet.reader.SimpleSliceInputStream; -import io.trino.parquet.reader.TestingColumnReader; import io.trino.parquet.reader.flat.ColumnAdapter; import io.trino.parquet.reader.flat.DictionaryDecoder; import io.trino.spi.type.DecimalType; @@ -134,7 +134,7 @@ public void testDecoder( DataBuffer dataBuffer = inputDataProvider.write(valuesWriter, dataSize); Optional dictionaryPage = Optional.ofNullable(dataBuffer.dictionaryPage()) - .map(TestingColumnReader::toTrinoDictionaryPage); + .map(ParquetTestUtils::toTrinoDictionaryPage); Optional dictionary = dictionaryPage.map(page -> { try { return encoding.initDictionary(field.getDescriptor(), page); diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestRowRangesIterator.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestRowRangesIterator.java index 8861e0d75748..e5f933b9edee 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestRowRangesIterator.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/flat/TestRowRangesIterator.java @@ -21,8 +21,8 @@ import java.util.OptionalLong; import static io.trino.parquet.reader.FilteredRowRanges.RowRange; -import static org.apache.parquet.internal.filter2.columnindex.TestingRowRanges.toRowRange; -import static org.apache.parquet.internal.filter2.columnindex.TestingRowRanges.toRowRanges; +import static io.trino.parquet.reader.TestingRowRanges.toRowRange; +import static io.trino.parquet.reader.TestingRowRanges.toRowRanges; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestTrinoValuesWriterFactory.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestTrinoValuesWriterFactory.java new file mode 100644 index 000000000000..3abec18fe784 --- /dev/null +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestTrinoValuesWriterFactory.java @@ -0,0 +1,151 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.writer; + +import io.trino.parquet.writer.valuewriter.DictionaryFallbackValuesWriter; +import io.trino.parquet.writer.valuewriter.TrinoValuesWriterFactory; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.column.values.ValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainBinaryDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainDoubleDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainFixedLenArrayDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainFloatDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainIntegerDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainLongDictionaryValuesWriter; +import org.apache.parquet.column.values.plain.BooleanPlainValuesWriter; +import org.apache.parquet.column.values.plain.FixedLenByteArrayPlainValuesWriter; +import org.apache.parquet.column.values.plain.PlainValuesWriter; +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; +import org.testng.annotations.Test; + +import static java.util.Locale.ENGLISH; +import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0; +import static org.apache.parquet.schema.Types.required; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestTrinoValuesWriterFactory +{ + @Test + public void testBoolean() + { + testValueWriter(PrimitiveTypeName.BOOLEAN, BooleanPlainValuesWriter.class); + } + + @Test + public void testFixedLenByteArray() + { + testValueWriter(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, FixedLenByteArrayPlainValuesWriter.class); + } + + @Test + public void testBinary() + { + testValueWriter( + PrimitiveTypeName.BINARY, + PlainBinaryDictionaryValuesWriter.class, + PlainValuesWriter.class); + } + + @Test + public void testInt32() + { + testValueWriter( + PrimitiveTypeName.INT32, + PlainIntegerDictionaryValuesWriter.class, + PlainValuesWriter.class); + } + + @Test + public void testInt64() + { + testValueWriter( + PrimitiveTypeName.INT64, + PlainLongDictionaryValuesWriter.class, + PlainValuesWriter.class); + } + + @Test + public void testInt96() + { + testValueWriter( + PrimitiveTypeName.INT96, + PlainFixedLenArrayDictionaryValuesWriter.class, + FixedLenByteArrayPlainValuesWriter.class); + } + + @Test + public void testDouble() + { + testValueWriter( + PrimitiveTypeName.DOUBLE, + PlainDoubleDictionaryValuesWriter.class, + PlainValuesWriter.class); + } + + @Test + public void testFloat() + { + testValueWriter( + PrimitiveTypeName.FLOAT, + PlainFloatDictionaryValuesWriter.class, + PlainValuesWriter.class); + } + + private void testValueWriter(PrimitiveTypeName typeName, Class expectedValueWriterClass) + { + ColumnDescriptor mockPath = createColumnDescriptor(typeName); + TrinoValuesWriterFactory factory = new TrinoValuesWriterFactory(ParquetProperties.builder() + .withWriterVersion(PARQUET_1_0) + .build()); + ValuesWriter writer = factory.newValuesWriter(mockPath); + + validateWriterType(writer, expectedValueWriterClass); + } + + private void testValueWriter(PrimitiveTypeName typeName, Class initialValueWriterClass, Class fallbackValueWriterClass) + { + ColumnDescriptor mockPath = createColumnDescriptor(typeName); + TrinoValuesWriterFactory factory = new TrinoValuesWriterFactory(ParquetProperties.builder() + .withWriterVersion(PARQUET_1_0) + .build()); + ValuesWriter writer = factory.newValuesWriter(mockPath); + + validateFallbackWriter(writer, initialValueWriterClass, fallbackValueWriterClass); + } + + private ColumnDescriptor createColumnDescriptor(PrimitiveTypeName typeName) + { + return createColumnDescriptor(typeName, "fake_" + typeName.name().toLowerCase(ENGLISH) + "_col"); + } + + private ColumnDescriptor createColumnDescriptor(PrimitiveTypeName typeName, String name) + { + return new ColumnDescriptor(new String[] {name}, required(typeName).length(1).named(name), 0, 0); + } + + private void validateWriterType(ValuesWriter writer, Class valuesWriterClass) + { + assertThat(writer).isInstanceOf(valuesWriterClass); + } + + private void validateFallbackWriter(ValuesWriter writer, Class initialWriterClass, Class fallbackWriterClass) + { + validateWriterType(writer, DictionaryFallbackValuesWriter.class); + + DictionaryFallbackValuesWriter fallbackValuesWriter = (DictionaryFallbackValuesWriter) writer; + validateWriterType(fallbackValuesWriter.getInitialWriter(), initialWriterClass); + validateWriterType(fallbackValuesWriter.getFallBackWriter(), fallbackWriterClass); + } +} diff --git a/lib/trino-parquet/src/test/java/org/apache/parquet/column/values/dictionary/TestDictionaryWriter.java b/lib/trino-parquet/src/test/java/org/apache/parquet/column/values/dictionary/TestDictionaryWriter.java new file mode 100644 index 000000000000..bc0e1ac50e98 --- /dev/null +++ b/lib/trino-parquet/src/test/java/org/apache/parquet/column/values/dictionary/TestDictionaryWriter.java @@ -0,0 +1,705 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.column.values.dictionary; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.parquet.DictionaryPage; +import io.trino.parquet.reader.SimpleSliceInputStream; +import io.trino.parquet.reader.decoders.PlainByteArrayDecoders; +import io.trino.parquet.reader.decoders.PlainValueDecoders; +import io.trino.parquet.reader.decoders.ValueDecoder; +import io.trino.parquet.reader.flat.BinaryBuffer; +import io.trino.parquet.reader.flat.ColumnAdapter; +import io.trino.parquet.reader.flat.DictionaryDecoder; +import io.trino.parquet.writer.valuewriter.DictionaryFallbackValuesWriter; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.DirectByteBufferAllocator; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.values.ValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainBinaryDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainDoubleDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainFloatDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainIntegerDictionaryValuesWriter; +import org.apache.parquet.column.values.dictionary.DictionaryValuesWriter.PlainLongDictionaryValuesWriter; +import org.apache.parquet.column.values.plain.PlainValuesWriter; +import org.apache.parquet.io.api.Binary; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static io.trino.parquet.ParquetTestUtils.toTrinoDictionaryPage; +import static io.trino.parquet.reader.flat.BinaryColumnAdapter.BINARY_ADAPTER; +import static io.trino.parquet.reader.flat.IntColumnAdapter.INT_ADAPTER; +import static io.trino.parquet.reader.flat.LongColumnAdapter.LONG_ADAPTER; +import static org.apache.parquet.column.Encoding.PLAIN; +import static org.apache.parquet.column.Encoding.PLAIN_DICTIONARY; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; + +public class TestDictionaryWriter +{ + @Test + public void testBinaryDictionary() + throws IOException + { + int count = 100; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainBinaryDictionaryValuesWriter(200, 10000); + writeRepeated(count, fallbackValuesWriter, "a"); + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + writeRepeated(count, fallbackValuesWriter, "b"); + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + // now we will fall back + writeDistinct(count, fallbackValuesWriter, "c"); + BytesInput bytes3 = getBytesAndCheckEncoding(fallbackValuesWriter, PLAIN); + + ValueDecoder decoder = getDictionaryDecoder(fallbackValuesWriter, BINARY_ADAPTER, new PlainByteArrayDecoders.BinaryPlainValueDecoder()); + checkRepeated(count, bytes1, decoder, "a"); + checkRepeated(count, bytes2, decoder, "b"); + decoder = new PlainByteArrayDecoders.BinaryPlainValueDecoder(); + checkDistinct(count, bytes3, decoder, "c"); + } + + @Test + public void testSkipInBinaryDictionary() + throws Exception + { + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainBinaryDictionaryValuesWriter(1000, 10000); + writeRepeated(100, fallbackValuesWriter, "a"); + writeDistinct(100, fallbackValuesWriter, "b"); + assertThat(fallbackValuesWriter.getEncoding()).isEqualTo(getDictionaryEncoding()); + + // Test skip and skip-n with dictionary encoding + Slice writtenValues = Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()); + ValueDecoder decoder = getDictionaryDecoder(fallbackValuesWriter, BINARY_ADAPTER, new PlainByteArrayDecoders.BinaryPlainValueDecoder()); + decoder.init(new SimpleSliceInputStream(writtenValues)); + for (int i = 0; i < 100; i += 2) { + BinaryBuffer buffer = new BinaryBuffer(1); + decoder.read(buffer, 0, 1); + assertThat(buffer.asSlice().toStringUtf8()).isEqualTo("a" + i % 10); + decoder.skip(1); + } + int skipcount; + for (int i = 0; i < 100; i += skipcount + 1) { + BinaryBuffer buffer = new BinaryBuffer(1); + decoder.read(buffer, 0, 1); + assertThat(buffer.asSlice().toStringUtf8()).isEqualTo("b" + i); + skipcount = (100 - i) / 2; + decoder.skip(skipcount); + } + + // Ensure fallback + writeDistinct(1000, fallbackValuesWriter, "c"); + assertThat(fallbackValuesWriter.getEncoding()).isEqualTo(PLAIN); + + // Test skip and skip-n with plain encoding (after fallback) + decoder = new PlainByteArrayDecoders.BinaryPlainValueDecoder(); + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + decoder.skip(200); + for (int i = 0; i < 100; i += 2) { + BinaryBuffer buffer = new BinaryBuffer(1); + decoder.read(buffer, 0, 1); + assertThat(buffer.asSlice().toStringUtf8()).isEqualTo("c" + i); + decoder.skip(1); + } + for (int i = 100; i < 1000; i += skipcount + 1) { + BinaryBuffer buffer = new BinaryBuffer(1); + decoder.read(buffer, 0, 1); + assertThat(buffer.asSlice().toStringUtf8()).isEqualTo("c" + i); + skipcount = (1000 - i) / 2; + decoder.skip(skipcount); + } + } + + @Test + public void testBinaryDictionaryFallBack() + throws IOException + { + int slabSize = 100; + int maxDictionaryByteSize = 50; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainBinaryDictionaryValuesWriter(maxDictionaryByteSize, slabSize); + int dataSize = 0; + for (long i = 0; i < 100; i++) { + Binary binary = Binary.fromString("str" + i); + fallbackValuesWriter.writeBytes(binary); + dataSize += (binary.length() + 4); + assertThat(fallbackValuesWriter.getEncoding()).isEqualTo(dataSize < maxDictionaryByteSize ? getDictionaryEncoding() : PLAIN); + } + + // Fallback to Plain encoding, therefore use BinaryPlainValueDecoder to read it back + ValueDecoder decoder = new PlainByteArrayDecoders.BinaryPlainValueDecoder(); + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + BinaryBuffer buffer = new BinaryBuffer(100); + decoder.read(buffer, 0, 100); + Slice values = buffer.asSlice(); + int[] offsets = buffer.getOffsets(); + int currentOffset = 0; + for (int i = 0; i < 100; i++) { + int length = offsets[i + 1] - offsets[i]; + assertThat(values.slice(currentOffset, length).toStringUtf8()).isEqualTo("str" + i); + currentOffset += length; + } + + // simulate cutting the page + fallbackValuesWriter.reset(); + assertThat(fallbackValuesWriter.getBufferedSize()).isEqualTo(0); + } + + @Test + public void testBinaryDictionaryChangedValues() + throws IOException + { + int count = 100; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainBinaryDictionaryValuesWriter(200, 10000); + writeRepeatedWithReuse(count, fallbackValuesWriter, "a"); + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + writeRepeatedWithReuse(count, fallbackValuesWriter, "b"); + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + // now we will fall back + writeDistinct(count, fallbackValuesWriter, "c"); + BytesInput bytes3 = getBytesAndCheckEncoding(fallbackValuesWriter, PLAIN); + + ValueDecoder decoder = getDictionaryDecoder(fallbackValuesWriter, BINARY_ADAPTER, new PlainByteArrayDecoders.BinaryPlainValueDecoder()); + checkRepeated(count, bytes1, decoder, "a"); + checkRepeated(count, bytes2, decoder, "b"); + decoder = new PlainByteArrayDecoders.BinaryPlainValueDecoder(); + checkDistinct(count, bytes3, decoder, "c"); + } + + @Test + public void testFirstPageFallBack() + throws IOException + { + int count = 1000; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainBinaryDictionaryValuesWriter(10000, 10000); + writeDistinct(count, fallbackValuesWriter, "a"); + long dictionaryAllocatedSize = fallbackValuesWriter.getInitialWriter().getAllocatedSize(); + assertThat(fallbackValuesWriter.getAllocatedSize()).isEqualTo(dictionaryAllocatedSize); + // not efficient so falls back + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, PLAIN); + writeRepeated(count, fallbackValuesWriter, "b"); + assertThat(fallbackValuesWriter.getAllocatedSize()).isEqualTo(fallbackValuesWriter.getFallBackWriter().getAllocatedSize()); + // still plain because we fell back on first page + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, PLAIN); + + ValueDecoder decoder = new PlainByteArrayDecoders.BinaryPlainValueDecoder(); + checkDistinct(count, bytes1, decoder, "a"); + checkRepeated(count, bytes2, decoder, "b"); + } + + @Test + public void testSecondPageFallBack() + throws IOException + { + int count = 1000; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainBinaryDictionaryValuesWriter(1000, 10000); + writeRepeated(count, fallbackValuesWriter, "a"); + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + writeDistinct(count, fallbackValuesWriter, "b"); + // not efficient so falls back + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, PLAIN); + writeRepeated(count, fallbackValuesWriter, "a"); + // still plain because we fell back on previous page + BytesInput bytes3 = getBytesAndCheckEncoding(fallbackValuesWriter, PLAIN); + + ValueDecoder decoder = getDictionaryDecoder(fallbackValuesWriter, BINARY_ADAPTER, new PlainByteArrayDecoders.BinaryPlainValueDecoder()); + checkRepeated(count, bytes1, decoder, "a"); + decoder = new PlainByteArrayDecoders.BinaryPlainValueDecoder(); + checkDistinct(count, bytes2, decoder, "b"); + checkRepeated(count, bytes3, decoder, "a"); + } + + @Test + public void testLongDictionary() + throws IOException + { + int count = 1000; + int count2 = 2000; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainLongDictionaryValuesWriter(10000, 10000); + for (long i = 0; i < count; i++) { + fallbackValuesWriter.writeLong(i % 50); + } + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + for (long i = count2; i > 0; i--) { + fallbackValuesWriter.writeLong(i % 50); + } + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + DictionaryDecoder dictionaryDecoder = getDictionaryDecoder(fallbackValuesWriter, LONG_ADAPTER, new PlainValueDecoders.LongPlainValueDecoder()); + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes1.toByteArray()))); + long[] values = new long[count]; + dictionaryDecoder.read(values, 0, count); + for (int i = 0; i < count; i++) { + assertThat(values[i]).isEqualTo(i % 50); + } + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes2.toByteArray()))); + values = new long[count2]; + dictionaryDecoder.read(values, 0, count2); + for (int i = count2; i > 0; i--) { + assertThat(values[count2 - i]).isEqualTo(i % 50); + } + } + + @Test + public void testLongDictionaryFallBack() + throws IOException + { + int slabSize = 100; + int maxDictionaryByteSize = 50; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainLongDictionaryValuesWriter(maxDictionaryByteSize, slabSize); + // Fallback to Plain encoding, therefore use LongPlainValueDecoder to read it back + ValueDecoder decoder = new PlainValueDecoders.LongPlainValueDecoder(); + + roundTripLong(fallbackValuesWriter, decoder, maxDictionaryByteSize); + // simulate cutting the page + fallbackValuesWriter.reset(); + assertThat(fallbackValuesWriter.getBufferedSize()).isEqualTo(0); + fallbackValuesWriter.resetDictionary(); + + roundTripLong(fallbackValuesWriter, decoder, maxDictionaryByteSize); + } + + @Test + public void testDoubleDictionary() + throws IOException + { + int count = 1000; + int count2 = 2000; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainDoubleDictionaryValuesWriter(10000, 10000); + + for (double i = 0; i < count; i++) { + fallbackValuesWriter.writeDouble(i % 50); + } + + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + for (double i = count2; i > 0; i--) { + fallbackValuesWriter.writeDouble(i % 50); + } + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + DictionaryDecoder dictionaryDecoder = getDictionaryDecoder(fallbackValuesWriter, LONG_ADAPTER, new PlainValueDecoders.LongPlainValueDecoder()); + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes1.toByteArray()))); + long[] values = new long[count]; + dictionaryDecoder.read(values, 0, count); + for (int i = 0; i < count; i++) { + double back = Double.longBitsToDouble(values[i]); + assertThat(back).isEqualTo(i % 50); + } + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes2.toByteArray()))); + values = new long[count2]; + dictionaryDecoder.read(values, 0, count2); + for (int i = count2; i > 0; i--) { + double back = Double.longBitsToDouble(values[count2 - i]); + assertThat(back).isEqualTo(i % 50); + } + } + + @Test + public void testDoubleDictionaryFallBack() + throws IOException + { + int slabSize = 100; + int maxDictionaryByteSize = 50; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainDoubleDictionaryValuesWriter(maxDictionaryByteSize, slabSize); + + // Fallback to Plain encoding, therefore use LongPlainValueDecoder to read it back + ValueDecoder decoder = new PlainValueDecoders.LongPlainValueDecoder(); + + roundTripDouble(fallbackValuesWriter, decoder, maxDictionaryByteSize); + // simulate cutting the page + fallbackValuesWriter.reset(); + assertThat(fallbackValuesWriter.getBufferedSize()).isEqualTo(0); + fallbackValuesWriter.resetDictionary(); + + roundTripDouble(fallbackValuesWriter, decoder, maxDictionaryByteSize); + } + + @Test + public void testIntDictionary() + throws IOException + { + int count = 2000; + int count2 = 4000; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainIntegerDictionaryValuesWriter(10000, 10000); + + for (int i = 0; i < count; i++) { + fallbackValuesWriter.writeInteger(i % 50); + } + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + for (int i = count2; i > 0; i--) { + fallbackValuesWriter.writeInteger(i % 50); + } + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + DictionaryDecoder dictionaryDecoder = getDictionaryDecoder(fallbackValuesWriter, INT_ADAPTER, new PlainValueDecoders.IntPlainValueDecoder()); + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes1.toByteArray()))); + int[] values = new int[count]; + dictionaryDecoder.read(values, 0, count); + for (int i = 0; i < count; i++) { + assertThat(values[i]).isEqualTo(i % 50); + } + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes2.toByteArray()))); + values = new int[count2]; + dictionaryDecoder.read(values, 0, count2); + for (int i = count2; i > 0; i--) { + assertThat(values[count2 - i]).isEqualTo(i % 50); + } + } + + @Test + public void testIntDictionaryFallBack() + throws IOException + { + int slabSize = 100; + int maxDictionaryByteSize = 50; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainIntegerDictionaryValuesWriter(maxDictionaryByteSize, slabSize); + + // Fallback to Plain encoding, therefore use IntPlainValueDecoder to read it back + ValueDecoder decoder = new PlainValueDecoders.IntPlainValueDecoder(); + + roundTripInt(fallbackValuesWriter, decoder, maxDictionaryByteSize); + // simulate cutting the page + fallbackValuesWriter.reset(); + assertThat(fallbackValuesWriter.getBufferedSize()).isEqualTo(0); + fallbackValuesWriter.resetDictionary(); + + roundTripInt(fallbackValuesWriter, decoder, maxDictionaryByteSize); + } + + @Test + public void testFloatDictionary() + throws IOException + { + int count = 2000; + int count2 = 4000; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainFloatDictionaryValuesWriter(10000, 10000); + + for (float i = 0; i < count; i++) { + fallbackValuesWriter.writeFloat(i % 50); + } + BytesInput bytes1 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + for (float i = count2; i > 0; i--) { + fallbackValuesWriter.writeFloat(i % 50); + } + BytesInput bytes2 = getBytesAndCheckEncoding(fallbackValuesWriter, getDictionaryEncoding()); + assertThat(fallbackValuesWriter.getInitialWriter().getDictionarySize()).isEqualTo(50); + + DictionaryDecoder dictionaryDecoder = getDictionaryDecoder(fallbackValuesWriter, INT_ADAPTER, new PlainValueDecoders.IntPlainValueDecoder()); + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes1.toByteArray()))); + int[] values = new int[count]; + dictionaryDecoder.read(values, 0, count); + for (int i = 0; i < count; i++) { + float back = Float.intBitsToFloat(values[i]); + assertThat(back).isEqualTo(i % 50); + } + + dictionaryDecoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes2.toByteArray()))); + values = new int[count2]; + dictionaryDecoder.read(values, 0, count2); + for (int i = count2; i > 0; i--) { + float back = Float.intBitsToFloat(values[count2 - i]); + assertThat(back).isEqualTo(i % 50); + } + } + + @Test + public void testFloatDictionaryFallBack() + throws IOException + { + int slabSize = 100; + int maxDictionaryByteSize = 50; + DictionaryFallbackValuesWriter fallbackValuesWriter = newPlainFloatDictionaryValuesWriter(maxDictionaryByteSize, slabSize); + + // Fallback to Plain encoding, therefore use IntPlainValueDecoder to read it back + ValueDecoder decoder = new PlainValueDecoders.IntPlainValueDecoder(); + + roundTripFloat(fallbackValuesWriter, decoder, maxDictionaryByteSize); + // simulate cutting the page + fallbackValuesWriter.reset(); + assertThat(fallbackValuesWriter.getBufferedSize()).isEqualTo(0); + fallbackValuesWriter.resetDictionary(); + + roundTripFloat(fallbackValuesWriter, decoder, maxDictionaryByteSize); + } + + private void roundTripLong(DictionaryFallbackValuesWriter fallbackValuesWriter, ValueDecoder decoder, int maxDictionaryByteSize) + throws IOException + { + int fallBackThreshold = maxDictionaryByteSize / 8; + for (long i = 0; i < 100; i++) { + fallbackValuesWriter.writeLong(i); + assertThat(fallbackValuesWriter.getEncoding()).isEqualTo(i < fallBackThreshold ? getDictionaryEncoding() : PLAIN); + } + + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + long[] values = new long[100]; + decoder.read(values, 0, 100); + for (int i = 0; i < 100; i++) { + assertThat(values[i]).isEqualTo(i); + } + + // Test skip with plain encoding + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + values = new long[1]; + for (int i = 0; i < 100; i += 2) { + decoder.read(values, 0, 1); + assertThat(values[0]).isEqualTo(i); + decoder.skip(1); + } + + // Test skip-n with plain encoding + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + int skipcount; + for (int i = 0; i < 100; i += skipcount + 1) { + decoder.read(values, 0, 1); + assertThat(values[0]).isEqualTo(i); + skipcount = (100 - i) / 2; + decoder.skip(skipcount); + } + } + + private static void roundTripDouble(DictionaryFallbackValuesWriter fallbackValuesWriter, ValueDecoder decoder, int maxDictionaryByteSize) + throws IOException + { + int fallBackThreshold = maxDictionaryByteSize / 8; + for (double i = 0; i < 100; i++) { + fallbackValuesWriter.writeDouble(i); + assertThat(fallbackValuesWriter.getEncoding()).isEqualTo(i < fallBackThreshold ? getDictionaryEncoding() : PLAIN); + } + + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + long[] values = new long[100]; + decoder.read(values, 0, 100); + for (int i = 0; i < 100; i++) { + assertThat(Double.longBitsToDouble(values[i])).isEqualTo(i); + } + + // Test skip with plain encoding + values = new long[1]; + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + for (int i = 0; i < 100; i += 2) { + decoder.read(values, 0, 1); + assertThat(Double.longBitsToDouble(values[0])).isEqualTo(i); + decoder.skip(1); + } + + // Test skip-n with plain encoding + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + int skipcount; + for (int i = 0; i < 100; i += skipcount + 1) { + decoder.read(values, 0, 1); + assertThat(Double.longBitsToDouble(values[0])).isEqualTo(i); + skipcount = (100 - i) / 2; + decoder.skip(skipcount); + } + } + + private static void roundTripInt(DictionaryFallbackValuesWriter fallbackValuesWriter, ValueDecoder decoder, int maxDictionaryByteSize) + throws IOException + { + int fallBackThreshold = maxDictionaryByteSize / 4; + for (int i = 0; i < 100; i++) { + fallbackValuesWriter.writeInteger(i); + assertThat(fallbackValuesWriter.getEncoding()).isEqualTo(i < fallBackThreshold ? getDictionaryEncoding() : PLAIN); + } + + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + int[] values = new int[100]; + decoder.read(values, 0, 100); + for (int i = 0; i < 100; i++) { + assertThat(values[i]).isEqualTo(i); + } + + // Test skip with plain encoding + values = new int[1]; + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + for (int i = 0; i < 100; i += 2) { + decoder.read(values, 0, 1); + assertThat(values[0]).isEqualTo(i); + decoder.skip(1); + } + + // Test skip-n with plain encoding + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + int skipcount; + for (int i = 0; i < 100; i += skipcount + 1) { + decoder.read(values, 0, 1); + assertThat(values[0]).isEqualTo(i); + skipcount = (100 - i) / 2; + decoder.skip(skipcount); + } + } + + private static void roundTripFloat(DictionaryFallbackValuesWriter fallbackValuesWriter, ValueDecoder decoder, int maxDictionaryByteSize) + throws IOException + { + int fallBackThreshold = maxDictionaryByteSize / 4; + for (float i = 0; i < 100; i++) { + fallbackValuesWriter.writeFloat(i); + assertThat(fallbackValuesWriter.getEncoding()).isEqualTo(i < fallBackThreshold ? getDictionaryEncoding() : PLAIN); + } + + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + int[] values = new int[100]; + decoder.read(values, 0, 100); + for (int i = 0; i < 100; i++) { + assertThat(Float.intBitsToFloat(values[i])).isEqualTo(i); + } + + // Test skip with plain encoding + values = new int[1]; + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + for (int i = 0; i < 100; i += 2) { + decoder.read(values, 0, 1); + assertThat(Float.intBitsToFloat(values[0])).isEqualTo(i); + decoder.skip(1); + } + + // Test skip-n with plain encoding + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(fallbackValuesWriter.getBytes().toByteArray()))); + int skipcount; + for (int i = 0; i < 100; i += skipcount + 1) { + decoder.read(values, 0, 1); + assertThat(Float.intBitsToFloat(values[0])).isEqualTo(i); + skipcount = (100 - i) / 2; + decoder.skip(skipcount); + } + } + + private static DictionaryDecoder getDictionaryDecoder(ValuesWriter valuesWriter, ColumnAdapter columnAdapter, ValueDecoder plainValuesDecoder) + throws IOException + { + DictionaryPage dictionaryPage = toTrinoDictionaryPage(valuesWriter.toDictPageAndClose().copy()); + return DictionaryDecoder.getDictionaryDecoder(dictionaryPage, columnAdapter, plainValuesDecoder, true); + } + + private static void checkDistinct(int count, BytesInput bytes, ValueDecoder decoder, String prefix) + throws IOException + { + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes.toByteArray()))); + BinaryBuffer buffer = new BinaryBuffer(count); + decoder.read(buffer, 0, count); + Slice values = buffer.asSlice(); + int[] offsets = buffer.getOffsets(); + int currentOffset = 0; + for (int i = 0; i < count; i++) { + int length = offsets[i + 1] - offsets[i]; + assertThat(values.slice(currentOffset, length).toStringUtf8()).isEqualTo(prefix + i); + currentOffset += length; + } + } + + private static void checkRepeated(int count, BytesInput bytes, ValueDecoder decoder, String prefix) + throws IOException + { + decoder.init(new SimpleSliceInputStream(Slices.wrappedBuffer(bytes.toByteArray()))); + BinaryBuffer buffer = new BinaryBuffer(count); + decoder.read(buffer, 0, count); + Slice values = buffer.asSlice(); + int[] offsets = buffer.getOffsets(); + int currentOffset = 0; + for (int i = 0; i < count; i++) { + int length = offsets[i + 1] - offsets[i]; + assertThat(values.slice(currentOffset, length).toStringUtf8()).isEqualTo(prefix + i % 10); + currentOffset += length; + } + } + + private static void writeDistinct(int count, ValuesWriter valuesWriter, String prefix) + { + for (int i = 0; i < count; i++) { + valuesWriter.writeBytes(Binary.fromString(prefix + i)); + } + } + + private static void writeRepeated(int count, ValuesWriter valuesWriter, String prefix) + { + for (int i = 0; i < count; i++) { + valuesWriter.writeBytes(Binary.fromString(prefix + i % 10)); + } + } + + private static void writeRepeatedWithReuse(int count, ValuesWriter valuesWriter, String prefix) + { + Binary reused = Binary.fromReusedByteArray((prefix + "0").getBytes(StandardCharsets.UTF_8)); + for (int i = 0; i < count; i++) { + Binary content = Binary.fromString(prefix + i % 10); + System.arraycopy(content.getBytesUnsafe(), 0, reused.getBytesUnsafe(), 0, reused.length()); + valuesWriter.writeBytes(reused); + } + } + + private static BytesInput getBytesAndCheckEncoding(ValuesWriter valuesWriter, Encoding encoding) + throws IOException + { + BytesInput bytes = BytesInput.copy(valuesWriter.getBytes()); + assertThat(valuesWriter.getEncoding()).isEqualTo(encoding); + valuesWriter.reset(); + return bytes; + } + + private static DictionaryFallbackValuesWriter plainFallBack(DictionaryValuesWriter dictionaryValuesWriter, int initialSize) + { + return new DictionaryFallbackValuesWriter(dictionaryValuesWriter, new PlainValuesWriter(initialSize, initialSize * 5, new DirectByteBufferAllocator())); + } + + private static DictionaryFallbackValuesWriter newPlainBinaryDictionaryValuesWriter(int maxDictionaryByteSize, int initialSize) + { + return plainFallBack(new PlainBinaryDictionaryValuesWriter(maxDictionaryByteSize, getDictionaryEncoding(), getDictionaryEncoding(), new DirectByteBufferAllocator()), initialSize); + } + + private static DictionaryFallbackValuesWriter newPlainLongDictionaryValuesWriter(int maxDictionaryByteSize, int initialSize) + { + return plainFallBack(new PlainLongDictionaryValuesWriter(maxDictionaryByteSize, getDictionaryEncoding(), getDictionaryEncoding(), new DirectByteBufferAllocator()), initialSize); + } + + private static DictionaryFallbackValuesWriter newPlainIntegerDictionaryValuesWriter(int maxDictionaryByteSize, int initialSize) + { + return plainFallBack(new PlainIntegerDictionaryValuesWriter(maxDictionaryByteSize, getDictionaryEncoding(), getDictionaryEncoding(), new DirectByteBufferAllocator()), initialSize); + } + + private static DictionaryFallbackValuesWriter newPlainDoubleDictionaryValuesWriter(int maxDictionaryByteSize, int initialSize) + { + return plainFallBack(new PlainDoubleDictionaryValuesWriter(maxDictionaryByteSize, getDictionaryEncoding(), getDictionaryEncoding(), new DirectByteBufferAllocator()), initialSize); + } + + private static DictionaryFallbackValuesWriter newPlainFloatDictionaryValuesWriter(int maxDictionaryByteSize, int initialSize) + { + return plainFallBack(new PlainFloatDictionaryValuesWriter(maxDictionaryByteSize, getDictionaryEncoding(), getDictionaryEncoding(), new DirectByteBufferAllocator()), initialSize); + } + + @SuppressWarnings("deprecation") + private static Encoding getDictionaryEncoding() + { + return PLAIN_DICTIONARY; + } +}