diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcDecompressor.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcDecompressor.java index 773d21f3cab9c..1c046553b36b4 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcDecompressor.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcDecompressor.java @@ -28,6 +28,12 @@ public interface OrcDecompressor static Optional createOrcDecompressor(OrcDataSourceId orcDataSourceId, CompressionKind compression, int bufferSize) throws OrcCorruptionException + { + return createOrcDecompressor(orcDataSourceId, compression, bufferSize, false); + } + + static Optional createOrcDecompressor(OrcDataSourceId orcDataSourceId, CompressionKind compression, int bufferSize, boolean zstdJniDecompressionEnabled) + throws OrcCorruptionException { if ((compression != NONE) && ((bufferSize <= 0) || (bufferSize > MAX_BUFFER_SIZE))) { throw new OrcCorruptionException(orcDataSourceId, "Invalid compression block size: " + bufferSize); @@ -42,7 +48,7 @@ static Optional createOrcDecompressor(OrcDataSourceId orcDataSo case LZ4: return Optional.of(new OrcLz4Decompressor(orcDataSourceId, bufferSize)); case ZSTD: - return Optional.of(new OrcZstdDecompressor(orcDataSourceId, bufferSize)); + return Optional.of(new OrcZstdDecompressor(orcDataSourceId, bufferSize, zstdJniDecompressionEnabled)); default: throw new OrcCorruptionException(orcDataSourceId, "Unknown compression type: " + compression); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java index 621d464168612..c96052f00599a 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java @@ -97,7 +97,7 @@ public OrcReader( OrcFileTail orcFileTail = orcFileTailSource.getOrcFileTail(orcDataSource, metadataReader, writeValidation); this.bufferSize = orcFileTail.getBufferSize(); this.compressionKind = orcFileTail.getCompressionKind(); - this.decompressor = createOrcDecompressor(orcDataSource.getId(), compressionKind, bufferSize); + this.decompressor = createOrcDecompressor(orcDataSource.getId(), compressionKind, bufferSize, orcReaderOptions.isOrcZstdJniDecompressionEnabled()); this.hiveWriterVersion = orcFileTail.getHiveWriterVersion(); try (InputStream footerInputStream = new OrcInputStream(orcDataSource.getId(), orcFileTail.getFooterSlice().getInput(), decompressor, newSimpleAggregatedMemoryContext(), orcFileTail.getFooterSize())) { diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java index 4497c0cf36b26..6a1f5cea3ac8d 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.orc; +import com.github.luben.zstd.Zstd; import io.airlift.compress.MalformedInputException; import io.airlift.compress.zstd.ZstdDecompressor; @@ -24,12 +25,26 @@ class OrcZstdDecompressor { private final OrcDataSourceId orcDataSourceId; private final int maxBufferSize; - private final ZstdDecompressor decompressor = new ZstdDecompressor(); + private final Decompressor decompressor; - public OrcZstdDecompressor(OrcDataSourceId orcDataSourceId, int maxBufferSize) + public OrcZstdDecompressor(OrcDataSourceId orcDataSourceId, int maxBufferSize, boolean zstdJniDecompressionEnabled) { this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSourceId is null"); this.maxBufferSize = maxBufferSize; + if (zstdJniDecompressionEnabled) { + this.decompressor = (input, inputOffset, inputLength, output, outputOffset, maxOutputLength) -> { + long size = Zstd.decompressByteArray(output, 0, maxOutputLength, input, inputOffset, inputLength); + if (Zstd.isError(size)) { + String errorName = Zstd.getErrorName(size); + throw new MalformedInputException(inputOffset, "Zstd JNI decompressor failed with " + errorName); + } + return toIntExact(size); + }; + } + else { + ZstdDecompressor zstdDecompressor = new ZstdDecompressor(); + this.decompressor = zstdDecompressor::decompress; + } } @Override @@ -55,4 +70,10 @@ public String toString() { return "zstd"; } + + interface Decompressor + { + int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) + throws MalformedInputException; + } } diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkBatchStreamReaders.java b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkBatchStreamReaders.java index f02b2def5471a..e61da8527d4c0 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkBatchStreamReaders.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkBatchStreamReaders.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.type.TypeRegistry; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; @@ -45,7 +46,6 @@ import java.io.File; import java.io.IOException; import java.math.BigInteger; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Random; @@ -89,12 +89,12 @@ public Object readBlocks(BenchmarkData data) throws Throwable { OrcBatchRecordReader recordReader = data.createRecordReader(); - List blocks = new ArrayList<>(); + ImmutableList.Builder blocks = new ImmutableList.Builder<>(); while (recordReader.nextBatch() > 0) { Block block = recordReader.readBlock(0); blocks.add(block); } - return blocks; + return blocks.build(); } @State(Scope.Thread) @@ -154,7 +154,7 @@ public void tearDown() deleteRecursively(temporaryDirectory.toPath(), ALLOW_INSECURE); } - protected final List createValues() + protected List createValues() { switch (withNulls) { case ALL: @@ -166,7 +166,7 @@ protected final List createValues() } } - private final Object createValue() + private Object createValue() { switch (typeSignature) { case "boolean": diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkBatchStreamReadersWithZstd.java b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkBatchStreamReadersWithZstd.java new file mode 100644 index 0000000000000..9a43b83e60330 --- /dev/null +++ b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkBatchStreamReadersWithZstd.java @@ -0,0 +1,256 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.orc; + +import com.facebook.presto.orc.cache.StorageOrcFileTailSource; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.SqlDecimal; +import com.facebook.presto.spi.type.SqlTimestamp; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.type.TypeRegistry; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.units.DataSize; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.io.File; +import java.io.IOException; +import java.math.BigInteger; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static com.facebook.presto.orc.OrcReader.INITIAL_BATCH_SIZE; +import static com.facebook.presto.orc.OrcTester.Format.DWRF; +import static com.facebook.presto.orc.OrcTester.writeOrcColumnPresto; +import static com.facebook.presto.orc.metadata.CompressionKind.ZSTD; +import static com.facebook.presto.spi.type.DecimalType.createDecimalType; +import static com.facebook.presto.spi.type.TimeZoneKey.UTC_KEY; +import static com.google.common.io.Files.createTempDir; +import static com.google.common.io.MoreFiles.deleteRecursively; +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static java.util.UUID.randomUUID; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.stream.Collectors.toList; +import static org.joda.time.DateTimeZone.UTC; + +@SuppressWarnings("MethodMayBeStatic") +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.SECONDS) +@Fork(3) +@Warmup(iterations = 20, time = 500, timeUnit = MILLISECONDS) +@Measurement(iterations = 20, time = 500, timeUnit = MILLISECONDS) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkBatchStreamReadersWithZstd +{ + private static final DecimalType SHORT_DECIMAL_TYPE = createDecimalType(10, 5); + private static final DecimalType LONG_DECIMAL_TYPE = createDecimalType(30, 10); + private static final int ROWS = 10_000_000; + private static final int MAX_STRING = 10; + private static final List NULL_VALUES = Collections.nCopies(ROWS, null); + + @Benchmark + public Object readBlocksWithoutJni(BenchmarkData data) + throws Throwable + { + OrcBatchRecordReader recordReader = data.createRecordReader(false); + ImmutableList.Builder blocks = new ImmutableList.Builder<>(); + while (recordReader.nextBatch() > 0) { + Block block = recordReader.readBlock(0); + blocks.add(block); + } + return blocks.build(); + } + + @Benchmark + public Object readBlocksWithJni(BenchmarkData data) + throws Throwable + { + OrcBatchRecordReader recordReader = data.createRecordReader(true); + ImmutableList.Builder blocks = new ImmutableList.Builder<>(); + while (recordReader.nextBatch() > 0) { + Block block = recordReader.readBlock(0); + blocks.add(block); + } + return blocks.build(); + } + + @State(Scope.Thread) + public static class BenchmarkData + { + private final Random random = new Random(0); + + private Type type; + private File temporaryDirectory; + private File orcFile; + private final OrcTester.Format format = DWRF; + + @SuppressWarnings("unused") + @Param({ + "boolean", + "tinyint", + "smallint", + "integer", + "bigint", + "decimal(10,5)", + "decimal(30,10)", + "timestamp", + "real", + "double", + "varchar_direct", + "varchar_dictionary", + }) + private String typeSignature; + + @SuppressWarnings("unused") + @Param({ + "PARTIAL", + "NONE", + "ALL" + }) + private Nulls withNulls; + + @Setup + public void setup() + throws Exception + { + if (typeSignature.startsWith("varchar")) { + type = new TypeRegistry().getType(TypeSignature.parseTypeSignature("varchar")); + } + else { + type = new TypeRegistry().getType(TypeSignature.parseTypeSignature(typeSignature)); + } + + temporaryDirectory = createTempDir(); + orcFile = new File(temporaryDirectory, randomUUID().toString()); + writeOrcColumnPresto(orcFile, format, ZSTD, type, createValues()); + } + + @TearDown + public void tearDown() + throws IOException + { + deleteRecursively(temporaryDirectory.toPath(), ALLOW_INSECURE); + } + + protected List createValues() + { + switch (withNulls) { + case ALL: + return NULL_VALUES; + case PARTIAL: + return IntStream.range(0, ROWS).mapToObj(i -> i % 2 == 0 ? createValue() : null).collect(toList()); + default: + return IntStream.range(0, ROWS).mapToObj(i -> createValue()).collect(toList()); + } + } + + private Object createValue() + { + switch (typeSignature) { + case "boolean": + return random.nextBoolean(); + case "tinyint": + return Long.valueOf(random.nextLong()).byteValue(); + case "smallint": + return (short) random.nextInt(); + case "integer": + return random.nextInt(); + case "bigint": + return random.nextLong(); + case "decimal(10,5)": + return new SqlDecimal(BigInteger.valueOf(random.nextLong() % 10_000_000_000L), SHORT_DECIMAL_TYPE.getPrecision(), SHORT_DECIMAL_TYPE.getScale()); + case "decimal(30,10)": + return new SqlDecimal(BigInteger.valueOf(random.nextLong() % 10_000_000_000L), LONG_DECIMAL_TYPE.getPrecision(), LONG_DECIMAL_TYPE.getScale()); + case "timestamp": + return new SqlTimestamp((random.nextLong()), UTC_KEY); + case "real": + return random.nextFloat(); + case "double": + return random.nextDouble(); + case "varchar_dictionary": + return Strings.repeat("0", MAX_STRING); + case "varchar_direct": + return randomAsciiString(random); + } + + throw new UnsupportedOperationException("Unsupported type: " + typeSignature); + } + + private OrcBatchRecordReader createRecordReader(boolean zstdJniDecompressionEnabled) + throws IOException + { + OrcDataSource dataSource = new FileOrcDataSource(orcFile, new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), true); + OrcReader orcReader = new OrcReader( + dataSource, + format.getOrcEncoding(), + new StorageOrcFileTailSource(), + new StorageStripeMetadataSource(), + OrcReaderTestingUtils.createTestingReaderOptions(zstdJniDecompressionEnabled)); + return orcReader.createBatchRecordReader( + ImmutableMap.of(0, type), + OrcPredicate.TRUE, + UTC, // arbitrary + newSimpleAggregatedMemoryContext(), + INITIAL_BATCH_SIZE); + } + + private static String randomAsciiString(Random random) + { + char[] value = new char[random.nextInt(MAX_STRING)]; + for (int i = 0; i < value.length; i++) { + value[i] = (char) random.nextInt(Byte.MAX_VALUE); + } + return new String(value); + } + + public enum Nulls + { + PARTIAL, NONE, ALL; + } + } + + public static void main(String[] args) + throws Throwable + { + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkBatchStreamReadersWithZstd.class.getSimpleName() + ".*") + .build(); + + new Runner(options).run(); + } +} diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkZstdJniDecompression.java b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkZstdJniDecompression.java new file mode 100644 index 0000000000000..28c9f56f90482 --- /dev/null +++ b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkZstdJniDecompression.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.orc; + +import com.facebook.presto.orc.zstd.ZstdJniCompressor; +import com.google.common.collect.ImmutableList; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.List; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.testng.Assert.assertEquals; + +@SuppressWarnings("MethodMayBeStatic") +@State(Scope.Thread) +@OutputTimeUnit(MILLISECONDS) +@Fork(3) +@Warmup(iterations = 20, time = 500, timeUnit = MILLISECONDS) +@Measurement(iterations = 20, time = 500, timeUnit = MILLISECONDS) +@BenchmarkMode(Mode.AverageTime) + +public class BenchmarkZstdJniDecompression +{ + private static final ZstdJniCompressor compressor = new ZstdJniCompressor(); + private static final List list = generateWorkload(); + private static final int sourceLength = 256 * 1024; + private static byte[] decompressedBytes = new byte[sourceLength]; + + @Benchmark + public void decompressJni() + throws OrcCorruptionException + { + decompressList(createOrcDecompressor(true)); + } + + @Benchmark + public void decompressJava() + throws OrcCorruptionException + { + decompressList(createOrcDecompressor(false)); + } + + private void decompressList(OrcDecompressor decompressor) + throws OrcCorruptionException + { + for (Unit unit : list) { + int outputSize = decompressor.decompress(unit.compressedBytes, 0, unit.compressedLength, new OrcDecompressor.OutputBuffer() + { + @Override + public byte[] initialize(int size) + { + return decompressedBytes; + } + + @Override + public byte[] grow(int size) + { + throw new RuntimeException(); + } + }); + assertEquals(outputSize, unit.sourceLength); + } + } + + private static List generateWorkload() + { + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + for (int i = 0; i < 10; i++) { + byte[] sourceBytes = getAlphaNumericString(sourceLength).getBytes(); + byte[] compressedBytes = new byte[sourceLength * 32]; + int size = compressor.compress(sourceBytes, 0, sourceBytes.length, compressedBytes, 0, compressedBytes.length); + builder.add(new Unit(sourceBytes, sourceLength, compressedBytes, size)); + } + return builder.build(); + } + + private OrcDecompressor createOrcDecompressor(boolean zstdJniDecompressionEnabled) + { + return new OrcZstdDecompressor(new OrcDataSourceId("orc"), sourceLength, zstdJniDecompressionEnabled); + } + + private static String getAlphaNumericString(int length) + { + String alphaNumericString = "USINDIA"; + + StringBuilder stringBuilder = new StringBuilder(length); + + for (int index = 0; index < length; index++) { + int arrayIndex = (int) (alphaNumericString.length() * Math.random()); + + stringBuilder.append(alphaNumericString.charAt(arrayIndex)); + } + return stringBuilder.toString(); + } + + static class Unit + { + final byte[] sourceBytes; + final int sourceLength; + final byte[] compressedBytes; + final int compressedLength; + + public Unit(byte[] sourceBytes, int sourceLength, byte[] compressedBytes, int compressedLength) + { + this.sourceBytes = sourceBytes; + this.sourceLength = sourceLength; + this.compressedBytes = compressedBytes; + this.compressedLength = compressedLength; + } + } +} diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/OrcReaderTestingUtils.java b/presto-orc/src/test/java/com/facebook/presto/orc/OrcReaderTestingUtils.java index 5a95ba093039b..2b91af7b2bf0c 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/OrcReaderTestingUtils.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/OrcReaderTestingUtils.java @@ -22,11 +22,16 @@ public class OrcReaderTestingUtils private OrcReaderTestingUtils() {} public static OrcReaderOptions createDefaultTestConfig() + { + return createTestingReaderOptions(false); + } + + public static OrcReaderOptions createTestingReaderOptions(boolean zstdJniDecompressionEnabled) { return new OrcReaderOptions( new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), - false); + zstdJniDecompressionEnabled); } } diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java b/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java index b7aff6ba4fa4a..34b1d994a6c28 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java @@ -1313,6 +1313,12 @@ static OrcBatchRecordReader createCustomOrcRecordReader( return orcReader.createBatchRecordReader(columnTypes, predicate, HIVE_STORAGE_TIME_ZONE, newSimpleAggregatedMemoryContext(), initialBatchSize); } + public static void writeOrcColumnPresto(File outputFile, Format format, CompressionKind compression, Type type, List values) + throws Exception + { + writeOrcColumnsPresto(outputFile, format, compression, ImmutableList.of(type), ImmutableList.of(values), new OrcWriterStats()); + } + private static void writeOrcColumnsPresto(File outputFile, Format format, CompressionKind compression, List types, List> values, OrcWriterStats stats) throws Exception { diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/TestZstdJniDecompression.java b/presto-orc/src/test/java/com/facebook/presto/orc/TestZstdJniDecompression.java new file mode 100644 index 0000000000000..767ffffd742e1 --- /dev/null +++ b/presto-orc/src/test/java/com/facebook/presto/orc/TestZstdJniDecompression.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.orc; + +import com.facebook.presto.orc.zstd.ZstdJniCompressor; +import com.facebook.presto.testing.assertions.Assert; +import io.airlift.units.DataSize; +import org.testng.annotations.Test; + +import java.util.Random; + +public class TestZstdJniDecompression +{ + private static final DataSize MAX_BUFFER_SIZE = new DataSize(4, DataSize.Unit.MEGABYTE); + private final ZstdJniCompressor compressor = new ZstdJniCompressor(); + private final OrcZstdDecompressor decompressor = new OrcZstdDecompressor(new OrcDataSourceId("test"), (int) MAX_BUFFER_SIZE.toBytes(), true); + + @Test + public void testDecompression() + throws OrcCorruptionException + { + byte[] sourceBytes = generateRandomBytes(); + byte[] compressedBytes = new byte[1024 * 1024]; + int size = compressor.compress(sourceBytes, 0, sourceBytes.length, compressedBytes, 0, compressedBytes.length); + byte[] output = new byte[sourceBytes.length]; + int outputSize = decompressor.decompress( + compressedBytes, + 0, + size, + new OrcDecompressor.OutputBuffer() + { + @Override + public byte[] initialize(int size) + { + return output; + } + + @Override + public byte[] grow(int size) + { + throw new RuntimeException(); + } + }); + Assert.assertEquals(outputSize, sourceBytes.length); + Assert.assertEquals(output, sourceBytes); + } + + private byte[] generateRandomBytes() + { + Random random = new Random(); + byte[] array = new byte[1024]; + random.nextBytes(array); + return array; + } +}