diff --git a/src/java/net/jpountz/lz4/LZ4BlockInputStream.java b/src/java/net/jpountz/lz4/LZ4BlockInputStream.java index 21ff671c..ddf425fd 100644 --- a/src/java/net/jpountz/lz4/LZ4BlockInputStream.java +++ b/src/java/net/jpountz/lz4/LZ4BlockInputStream.java @@ -29,7 +29,6 @@ import java.util.zip.Checksum; import net.jpountz.util.SafeUtils; -import net.jpountz.util.Utils; import net.jpountz.xxhash.StreamingXXHash32; import net.jpountz.xxhash.XXHash32; import net.jpountz.xxhash.XXHashFactory; @@ -44,6 +43,7 @@ public final class LZ4BlockInputStream extends FilterInputStream { private final LZ4FastDecompressor decompressor; private final Checksum checksum; + private final boolean stopOnEmptyBlock; private byte[] buffer; private byte[] compressedBuffer; private int originalLen; @@ -53,17 +53,19 @@ public final class LZ4BlockInputStream extends FilterInputStream { /** * Create a new {@link InputStream}. * - * @param in the {@link InputStream} to poll - * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to - * use - * @param checksum the {@link Checksum} instance to use, must be - * equivalent to the instance which has been used to - * write the stream + * @param in the {@link InputStream} to poll + * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to + * use + * @param checksum the {@link Checksum} instance to use, must be + * equivalent to the instance which has been used to + * write the stream + * @param stopOnEmptyBlock whether read is stopped on an empty block */ - public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum, boolean stopOnEmptyBlock) { super(in); this.decompressor = decompressor; this.checksum = checksum; + this.stopOnEmptyBlock = stopOnEmptyBlock; this.buffer = new byte[0]; this.compressedBuffer = new byte[HEADER_LENGTH]; o = originalLen = 0; @@ -75,8 +77,26 @@ public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Che * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) * @see StreamingXXHash32#asChecksum() */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { + this(in, decompressor, checksum, true); + } + + /** + * Create a new instance using {@link XXHash32} for checksuming. + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum, boolean) + * @see StreamingXXHash32#asChecksum() + */ public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { - this(in, decompressor, XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); + this(in, decompressor, XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum(), true); + } + + /** + * Create a new instance using {@link XXHash32} for checksuming. + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum, boolean) + * @see StreamingXXHash32#asChecksum() + */ + public LZ4BlockInputStream(InputStream in, boolean stopOnEmptyBlock) { + this(in, LZ4Factory.fastestInstance().fastDecompressor(), XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum(), stopOnEmptyBlock); } /** @@ -147,7 +167,16 @@ public long skip(long n) throws IOException { } private void refill() throws IOException { - readFully(compressedBuffer, HEADER_LENGTH); + try { + readFully(compressedBuffer, HEADER_LENGTH); + } catch (EOFException e) { + if (!stopOnEmptyBlock) { + finished = true; + } else { + throw e; + } + return; + } for (int i = 0; i < MAGIC_LENGTH; ++i) { if (compressedBuffer[i] != MAGIC[i]) { throw new IOException("Stream is corrupted"); @@ -175,7 +204,11 @@ private void refill() throws IOException { if (check != 0) { throw new IOException("Stream is corrupted"); } - finished = true; + if (!stopOnEmptyBlock) { + refill(); + } else { + finished = true; + } return; } if (buffer.length < originalLen) { diff --git a/src/test/net/jpountz/lz4/LZ4BlockStreamingTest.java b/src/test/net/jpountz/lz4/LZ4BlockStreamingTest.java index 3bb491da..903fceb9 100644 --- a/src/test/net/jpountz/lz4/LZ4BlockStreamingTest.java +++ b/src/test/net/jpountz/lz4/LZ4BlockStreamingTest.java @@ -293,4 +293,57 @@ public void testDoubleClose() throws IOException { in.close(); in.close(); } + + private static int readFully(InputStream in, byte[] b) throws IOException { + int total; + int result; + for (total = 0; total < b.length; total += result) { + result = in.read(b, total, b.length - total); + if(result == -1) { + break; + } + } + return total; + } + + @Test + public void testConcatenationOfSerializedStreams() throws IOException { + final byte[] testBytes1 = randomArray(64, 256); + final byte[] testBytes2 = randomArray(64, 256); + byte[] expected = new byte[128]; + System.arraycopy(testBytes1, 0, expected, 0, 64); + System.arraycopy(testBytes2, 0, expected, 64, 64); + + ByteArrayOutputStream bytes1os = new ByteArrayOutputStream(); + LZ4BlockOutputStream out1 = new LZ4BlockOutputStream(bytes1os); + out1.write(testBytes1); + out1.close(); + + ByteArrayOutputStream bytes2os = new ByteArrayOutputStream(); + LZ4BlockOutputStream out2 = new LZ4BlockOutputStream(bytes2os); + out2.write(testBytes2); + out2.close(); + + byte[] bytes1 = bytes1os.toByteArray(); + byte[] bytes2 = bytes2os.toByteArray(); + byte[] concatenatedBytes = new byte[bytes1.length + bytes2.length]; + System.arraycopy(bytes1, 0, concatenatedBytes, 0, bytes1.length); + System.arraycopy(bytes2, 0, concatenatedBytes, bytes1.length, bytes2.length); + + // In a default behaviour, we can read the first block of the concatenated bytes only + LZ4BlockInputStream in1 = new LZ4BlockInputStream(new ByteArrayInputStream(concatenatedBytes)); + byte[] actual1 = new byte[128]; + assertEquals(64, readFully(in1, actual1)); + assertEquals(-1, in1.read()); + in1.close(); + + // Check if we can read concatenated byte stream + LZ4BlockInputStream in2 = new LZ4BlockInputStream(new ByteArrayInputStream(concatenatedBytes), false); + byte[] actual2 = new byte[128]; + assertEquals(128, readFully(in2, actual2)); + assertEquals(-1, in2.read()); + in2.close(); + + assertArrayEquals(expected, actual2); + } }