From b26822902981dca2ab79eac0b7fee4fae69d3642 Mon Sep 17 00:00:00 2001 From: Zac Blanco Date: Wed, 4 Sep 2024 11:19:32 -0700 Subject: [PATCH] [parquet] Support 64-bit RLE-encoded ShortDecimal Previously, in the parquet writer short decimals could be written as RLE-encoded with an Int64 logical type. However, we lacked support in the reader to decode this type properly back into a short decimal. This commit adds support for the RLE-encoded 64-bit short decimals. --- .../parquet/AbstractTestParquetReader.java | 17 ++ .../batchreader/decoders/Decoders.java | 3 + .../rle/Int64RLEDictionaryValuesDecoder.java | 3 +- .../decoders/TestValuesDecoders.java | 166 ++++++++++++++++++ 4 files changed, 188 insertions(+), 1 deletion(-) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java index d71f243d83f9e..c409071bd332d 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java @@ -939,6 +939,23 @@ public void testDecimalBackedByINT64() } } + @Test + public void testRLEDecimalBackedByINT64() + throws Exception + { + int[] scales = {9, 9, 9, 9, 9, 9, 9, 9, 9}; + for (int precision = MAX_PRECISION_INT32 + 1; precision <= MAX_PRECISION_INT64; precision++) { + int scale = scales[precision - MAX_PRECISION_INT32 - 1]; + MessageType parquetSchema = parseMessageType(format("message hive_decimal { optional INT64 test (DECIMAL(%d, %d)); }", precision, scale)); + ContiguousSet longValues = longsBetween(1, 1_000); + ImmutableList.Builder expectedValues = new ImmutableList.Builder<>(); + for (Long value : longValues) { + expectedValues.add(SqlDecimal.of(value, precision, scale)); + } + tester.testRoundTrip(javaLongObjectInspector, longValues, expectedValues.build(), createDecimalType(precision, scale), Optional.of(parquetSchema)); + } + } + private void testDecimal(int precision, int scale, Optional parquetSchema) throws Exception { diff --git a/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/Decoders.java b/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/Decoders.java index fc181248ecdcf..e698211564b4f 100644 --- a/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/Decoders.java +++ b/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/Decoders.java @@ -179,6 +179,9 @@ private static ValuesDecoder createValuesDecoder(ColumnDescriptor columnDescript if (isTimeStampMicrosType(columnDescriptor) || isTimeMicrosType(columnDescriptor)) { return new Int64TimeAndTimestampMicrosRLEDictionaryValuesDecoder(bitWidth, inputStream, (LongDictionary) dictionary); } + if (isDecimalType(columnDescriptor) && isShortDecimalType(columnDescriptor)) { + return new Int64RLEDictionaryValuesDecoder(bitWidth, inputStream, (LongDictionary) dictionary); + } } case DOUBLE: { return new Int64RLEDictionaryValuesDecoder(bitWidth, inputStream, (LongDictionary) dictionary); diff --git a/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/rle/Int64RLEDictionaryValuesDecoder.java b/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/rle/Int64RLEDictionaryValuesDecoder.java index 8219b28f324a4..51d6da3836278 100644 --- a/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/rle/Int64RLEDictionaryValuesDecoder.java +++ b/presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/rle/Int64RLEDictionaryValuesDecoder.java @@ -14,6 +14,7 @@ package com.facebook.presto.parquet.batchreader.decoders.rle; import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int64ValuesDecoder; +import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.ShortDecimalValuesDecoder; import com.facebook.presto.parquet.dictionary.LongDictionary; import org.apache.parquet.io.ParquetDecodingException; import org.openjdk.jol.info.ClassLayout; @@ -27,7 +28,7 @@ public class Int64RLEDictionaryValuesDecoder extends BaseRLEBitPackedDecoder - implements Int64ValuesDecoder + implements Int64ValuesDecoder, ShortDecimalValuesDecoder { private static final int INSTANCE_SIZE = ClassLayout.parseClass(Int64RLEDictionaryValuesDecoder.class).instanceSize(); diff --git a/presto-parquet/src/test/java/com/facebook/presto/parquet/batchreader/decoders/TestValuesDecoders.java b/presto-parquet/src/test/java/com/facebook/presto/parquet/batchreader/decoders/TestValuesDecoders.java index a853c8d3f217e..35f356f6010c1 100644 --- a/presto-parquet/src/test/java/com/facebook/presto/parquet/batchreader/decoders/TestValuesDecoders.java +++ b/presto-parquet/src/test/java/com/facebook/presto/parquet/batchreader/decoders/TestValuesDecoders.java @@ -19,16 +19,20 @@ import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int32ValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int64TimeAndTimestampMicrosValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int64ValuesDecoder; +import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.ShortDecimalValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.TimestampValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.plain.BinaryPlainValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.plain.BooleanPlainValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.plain.Int32PlainValuesDecoder; +import com.facebook.presto.parquet.batchreader.decoders.plain.Int32ShortDecimalPlainValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.plain.Int64PlainValuesDecoder; +import com.facebook.presto.parquet.batchreader.decoders.plain.Int64ShortDecimalPlainValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.plain.Int64TimeAndTimestampMicrosPlainValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.plain.TimestampPlainValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.rle.BinaryRLEDictionaryValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.rle.BooleanRLEValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.rle.Int32RLEDictionaryValuesDecoder; +import com.facebook.presto.parquet.batchreader.decoders.rle.Int32ShortDecimalRLEDictionaryValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.rle.Int64RLEDictionaryValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.rle.Int64TimeAndTimestampMicrosRLEDictionaryValuesDecoder; import com.facebook.presto.parquet.batchreader.decoders.rle.TimestampRLEDictionaryValuesDecoder; @@ -118,6 +122,26 @@ private static BooleanValuesDecoder booleanRLE(byte[] pageBytes) return new BooleanRLEValuesDecoder(ByteBuffer.wrap(pageBytes)); } + private static ShortDecimalValuesDecoder int32ShortDecimalPlain(byte[] pageBytes) + { + return new Int32ShortDecimalPlainValuesDecoder(pageBytes, 0, pageBytes.length); + } + + private static ShortDecimalValuesDecoder int64ShortDecimalPlain(byte[] pageBytes) + { + return new Int64ShortDecimalPlainValuesDecoder(pageBytes, 0, pageBytes.length); + } + + private static ShortDecimalValuesDecoder int32ShortDecimalRLE(byte[] pageBytes, int dictionarySize, IntegerDictionary dictionary) + { + return new Int32ShortDecimalRLEDictionaryValuesDecoder(getWidthFromMaxInt(dictionarySize), new ByteArrayInputStream(pageBytes), dictionary); + } + + private static ShortDecimalValuesDecoder int64ShortDecimalRLE(byte[] pageBytes, int dictionarySize, LongDictionary dictionary) + { + return new Int64RLEDictionaryValuesDecoder(getWidthFromMaxInt(dictionarySize), new ByteArrayInputStream(pageBytes), dictionary); + } + private static void int32BatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, Int32ValuesDecoder decoder, List expectedValues) throws IOException { @@ -213,6 +237,52 @@ private static void int64BatchReadWithSkipHelper(int batchSize, int skipSize, in } } + private static void int32ShortDecimalBatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, ShortDecimalValuesDecoder decoder, List expectedValues) + throws IOException + { + long[] actualValues = new long[valueCount]; + int inputOffset = 0; + int outputOffset = 0; + while (inputOffset < valueCount) { + int readBatchSize = min(batchSize, valueCount - inputOffset); + decoder.readNext(actualValues, outputOffset, readBatchSize); + + for (int i = 0; i < readBatchSize; i++) { + assertEquals(actualValues[outputOffset + i], (int) expectedValues.get(inputOffset + i)); + } + + inputOffset += readBatchSize; + outputOffset += readBatchSize; + + int skipBatchSize = min(skipSize, valueCount - inputOffset); + decoder.skip(skipBatchSize); + inputOffset += skipBatchSize; + } + } + + private static void int64ShortDecimalBatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, ShortDecimalValuesDecoder decoder, List expectedValues) + throws IOException + { + long[] actualValues = new long[valueCount]; + int inputOffset = 0; + int outputOffset = 0; + while (inputOffset < valueCount) { + int readBatchSize = min(batchSize, valueCount - inputOffset); + decoder.readNext(actualValues, outputOffset, readBatchSize); + + for (int i = 0; i < readBatchSize; i++) { + assertEquals(actualValues[outputOffset + i], expectedValues.get(inputOffset + i)); + } + + inputOffset += readBatchSize; + outputOffset += readBatchSize; + + int skipBatchSize = min(skipSize, valueCount - inputOffset); + decoder.skip(skipBatchSize); + inputOffset += skipBatchSize; + } + } + private static void timestampBatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, TimestampValuesDecoder decoder, List expectedValues) throws IOException { @@ -515,4 +585,100 @@ public void testBooleanRLE() booleanBatchReadWithSkipHelper(89, 29, valueCount, booleanRLE(dataPage), expectedValues); booleanBatchReadWithSkipHelper(1024, 1024, valueCount, booleanRLE(dataPage), expectedValues); } + + @Test + public void testInt32ShortDecimalPlain() + throws IOException + { + int valueCount = 2048; + List expectedValues = new ArrayList<>(); + + byte[] pageBytes = generatePlainValuesPage(valueCount, 32, new Random(83), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); // read all values in one batch + int32ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); + + int32ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); + } + + @Test + public void testInt64ShortDecimalPlain() + throws IOException + { + int valueCount = 2048; + List expectedValues = new ArrayList<>(); + + byte[] pageBytes = generatePlainValuesPage(valueCount, 64, new Random(83), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); // read all values in one batch + int64ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); + + int64ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); + } + + @Test + public void testInt32ShortDecimalRLE() + throws IOException + { + Random random = new Random(83); + int valueCount = 2048; + int dictionarySize = 29; + List dictionary = new ArrayList<>(); + List dictionaryIds = new ArrayList<>(); + + byte[] dictionaryPage = generatePlainValuesPage(dictionarySize, 32, random, dictionary); + byte[] dataPage = generateDictionaryIdPage2048(dictionarySize - 1, random, dictionaryIds); + + List expectedValues = new ArrayList<>(); + for (Integer dictionaryId : dictionaryIds) { + expectedValues.add(dictionary.get(dictionaryId)); + } + + IntegerDictionary integerDictionary = new IntegerDictionary(new DictionaryPage(Slices.wrappedBuffer(dictionaryPage), dictionarySize, PLAIN_DICTIONARY)); + + int32ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); // read all values in one batch + int32ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); + + int32ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); + int32ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); + } + + @Test + public void testInt64ShortDecimalRLE() + throws IOException + { + Random random = new Random(83); + int valueCount = 2048; + int dictionarySize = 29; + List dictionary = new ArrayList<>(); + List dictionaryIds = new ArrayList<>(); + + byte[] dictionaryPage = generatePlainValuesPage(dictionarySize, 64, random, dictionary); + byte[] dataPage = generateDictionaryIdPage2048(dictionarySize - 1, random, dictionaryIds); + + List expectedValues = new ArrayList<>(); + for (Integer dictionaryId : dictionaryIds) { + expectedValues.add(dictionary.get(dictionaryId)); + } + + LongDictionary longDictionary = new LongDictionary(new DictionaryPage(Slices.wrappedBuffer(dictionaryPage), dictionarySize, PLAIN_DICTIONARY)); + + int64ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); // read all values in one batch + int64ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); + + int64ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); + int64ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); + } }