From e6704006fecfb0653851de698d85164d9837207a Mon Sep 17 00:00:00 2001 From: Richard DiCroce Date: Wed, 12 Oct 2022 11:12:43 -0400 Subject: [PATCH] Fix #151: add support for magicless frames --- src/main/java/com/github/luben/zstd/Zstd.java | 38 +++++++++++++++++-- .../github/luben/zstd/ZstdCompressCtx.java | 14 +++++++ .../github/luben/zstd/ZstdDecompressCtx.java | 14 +++++++ src/main/native/jni_zstd.c | 36 ++++++++++++++++-- src/test/scala/Zstd.scala | 22 +++++++++++ 5 files changed, 116 insertions(+), 8 deletions(-) diff --git a/src/main/java/com/github/luben/zstd/Zstd.java b/src/main/java/com/github/luben/zstd/Zstd.java index ebfac7d..dfb62b3 100644 --- a/src/main/java/com/github/luben/zstd/Zstd.java +++ b/src/main/java/com/github/luben/zstd/Zstd.java @@ -561,10 +561,12 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff public static native int loadDictCompress(long stream, byte[] dict, int dict_size); public static native int loadFastDictCompress(long stream, ZstdDictCompress dict); public static native int setCompressionChecksums(long stream, boolean useChecksums); + public static native int setCompressionMagicless(long stream, boolean useMagicless); public static native int setCompressionLevel(long stream, int level); public static native int setCompressionLong(long stream, int windowLog); public static native int setCompressionWorkers(long stream, int workers); public static native int setDecompressionLongMax(long stream, int windowLogMax); + public static native int setDecompressionMagicless(long stream, boolean useMagicless); /* Utility methods */ @@ -574,21 +576,35 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff * @param src the compressed buffer * @param srcPosition offset of the compressed data inside the src buffer * @param srcSize length of the compressed data inside the src buffer + * @param magicless whether the buffer contains a magicless frame * @return the number of bytes of 
the original buffer * 0 if the original size is not known */ - public static long decompressedSize(byte[] src, int srcPosition, int srcSize) { + public static long decompressedSize(byte[] src, int srcPosition, int srcSize, boolean magicless) { if (srcPosition >= src.length) { throw new ArrayIndexOutOfBoundsException(srcPosition); } if (srcPosition + srcSize > src.length) { throw new ArrayIndexOutOfBoundsException(srcPosition + srcSize); } - return decompressedSize0(src, srcPosition, srcSize); + return decompressedSize0(src, srcPosition, srcSize, magicless); } - private static native long decompressedSize0(byte[] src, int srcPosition, int srcSize); + private static native long decompressedSize0(byte[] src, int srcPosition, int srcSize, boolean magicless); + /** + * Return the original size of a compressed buffer (if known) + * + * @param src the compressed buffer + * @param srcPosition offset of the compressed data inside the src buffer + * @param srcSize length of the compressed data inside the src buffer + * @return the number of bytes of the original buffer + * 0 if the original size is not known + */ + public static long decompressedSize(byte[] src, int srcPosition, int srcSize) { + return decompressedSize(src, srcPosition, srcSize, false); + } + /** * Return the original size of a compressed buffer (if known) * @@ -618,10 +634,24 @@ public static long decompressedSize(byte[] src) { * @param src the compressed buffer * @param srcPosition offset of the compressed data inside the src buffer * @param srcSize length of the compressed data inside the src buffe + * @param magicless whether the buffer contains a magicless frame * @return the number of bytes of the original buffer * 0 if the original size is not known */ - public static native long decompressedDirectByteBufferSize(ByteBuffer src, int srcPosition, int srcSize); + public static native long decompressedDirectByteBufferSize(ByteBuffer src, int srcPosition, int srcSize, boolean magicless); + + /** + * Return 
the original size of a compressed buffer (if known) + * + * @param src the compressed buffer + * @param srcPosition offset of the compressed data inside the src buffer + * @param srcSize length of the compressed data inside the src buffer + * @return the number of bytes of the original buffer + * 0 if the original size is not known + */ + public static long decompressedDirectByteBufferSize(ByteBuffer src, int srcPosition, int srcSize) { + return decompressedDirectByteBufferSize(src, srcPosition, srcSize, false); + } /** * Maximum size of the compressed data diff --git a/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java b/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java index 470bfad..a741f7b 100644 --- a/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java +++ b/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java @@ -55,6 +55,20 @@ public ZstdCompressCtx setLevel(int level) { private native void setLevel0(int level); + /** + * Enable or disable magicless frames + * @param magiclessFlag If true, the 32-bit magic number at the start of each frame is omitted, default: false + */ + public ZstdCompressCtx setMagicless(boolean magiclessFlag) { + if (nativePtr == 0) { + throw new IllegalStateException("Compression context is closed"); + } + acquireSharedLock(); + Zstd.setCompressionMagicless(nativePtr, magiclessFlag); + releaseSharedLock(); + return this; + } + /** * Enable or disable compression checksums * @param checksumFlag A 32-bits checksum of content is written at end of frame, default: false diff --git a/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java b/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java index 46be5f4..8fba020 100644 --- a/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java +++ b/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java @@ -38,6 +38,20 @@ void doClose() { } } + /** + * Enable or disable magicless frames + * @param magiclessFlag If true, input frames are expected to have no 32-bit magic number at the start, default: false + 
*/ + public ZstdDecompressCtx setMagicless(boolean magiclessFlag) { + if (nativePtr == 0) { + throw new IllegalStateException("Compression context is closed"); + } + acquireSharedLock(); + Zstd.setDecompressionMagicless(nativePtr, magiclessFlag); + releaseSharedLock(); + return this; + } + /** * Load decompression dictionary * diff --git a/src/main/native/jni_zstd.c b/src/main/native/jni_zstd.c index 28981d0..812fd33 100644 --- a/src/main/native/jni_zstd.c +++ b/src/main/native/jni_zstd.c @@ -26,6 +26,22 @@ static size_t JNI_ZSTD_compress(void* dst, size_t dstCapacity, return size; } +/* + * Helper for determining decompressed size; returns 0 when the header cannot be parsed or the content size is unknown, like ZSTD_getDecompressedSize + */ +static size_t JNI_ZSTD_decompressedSize(const void* buf, size_t bufSize, jboolean magicless) { + if (magicless) { + ZSTD_frameHeader frameHeader; + if (ZSTD_getFrameHeader_advanced(&frameHeader, buf, bufSize, ZSTD_f_zstd1_magicless) != 0) { + return 0; + } + // note that skippable frames must have a magic number, so we don't need to consider that here + return frameHeader.frameContentSize == ZSTD_CONTENTSIZE_UNKNOWN ? 0 : frameHeader.frameContentSize; + } + + return ZSTD_getDecompressedSize(buf, bufSize); +} + /* * Class: com_github_luben_zstd_Zstd * Method: compressUnsafe @@ -52,11 +68,11 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_decompressUnsafe * Signature: ([B)JII */ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_decompressedSize0 - (JNIEnv *env, jclass obj, jbyteArray src, jint offset, jint limit) { + (JNIEnv *env, jclass obj, jbyteArray src, jint offset, jint limit, jboolean magicless) { size_t size = -ZSTD_error_memory_allocation; void *src_buff = (*env)->GetPrimitiveArrayCritical(env, src, NULL); if (src_buff == NULL) goto E1; - size = ZSTD_getDecompressedSize(((char *) src_buff) + offset, (size_t) limit); + size = JNI_ZSTD_decompressedSize(((char *) src_buff) + offset, (size_t) limit, magicless); (*env)->ReleasePrimitiveArrayCritical(env, src, src_buff, JNI_ABORT); E1: return size; } @@ -115,13 +131,13 @@ E1: return (jlong) dict_id; * Signature: 
(Ljava/nio/ByteBuffer;II)J */ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_decompressedDirectByteBufferSize - (JNIEnv *env, jclass obj, jobject src_buf, jint src_offset, jint src_size) { + (JNIEnv *env, jclass obj, jobject src_buf, jint src_offset, jint src_size, jboolean magicless) { size_t size = -ZSTD_error_memory_allocation; jsize src_cap = (*env)->GetDirectBufferCapacity(env, src_buf); if (src_offset + src_size > src_cap) return -ZSTD_error_GENERIC; char *src_buf_ptr = (char*)(*env)->GetDirectBufferAddress(env, src_buf); if (src_buf_ptr == NULL) goto E1; - size = ZSTD_getDecompressedSize(src_buf_ptr + src_offset, (size_t) src_size); + size = JNI_ZSTD_decompressedSize(src_buf_ptr + src_offset, (size_t) src_size, magicless); E1: return size; } @@ -240,6 +256,12 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setCompressionChecksums return ZSTD_CCtx_setParameter((ZSTD_CCtx *)(intptr_t) stream, ZSTD_c_checksumFlag, checksum); } +JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setCompressionMagicless + (JNIEnv *env, jclass obj, jlong stream, jboolean enabled) { + ZSTD_format_e format = enabled ? ZSTD_f_zstd1_magicless : ZSTD_f_zstd1; + return ZSTD_CCtx_setParameter((ZSTD_CCtx *)(intptr_t) stream, ZSTD_c_format, format); +} + /* * Class: com_github_luben_zstd_Zstd * Method: setCompressionLevel @@ -280,6 +302,12 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setDecompressionLongMax return ZSTD_DCtx_setParameter(dctx, ZSTD_d_windowLogMax, windowLogMax); } +JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setDecompressionMagicless + (JNIEnv *env, jclass obj, jlong stream, jboolean enabled) { + ZSTD_format_e format = enabled ? 
ZSTD_f_zstd1_magicless : ZSTD_f_zstd1; + return ZSTD_DCtx_setParameter((ZSTD_DCtx *)(intptr_t) stream, ZSTD_d_format, format); +} + /* * Class: com_github_luben_zstd_Zstd * Method: setCompressionWorkers diff --git a/src/test/scala/Zstd.scala b/src/test/scala/Zstd.scala index 6e80692..0c8cd50 100644 --- a/src/test/scala/Zstd.scala +++ b/src/test/scala/Zstd.scala @@ -919,4 +919,26 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { } }.get } + + "magicless frames" should "be magicless and roundtrip" in { + Using.Manager { use => + val cctx = use(new ZstdCompressCtx()) + val dctx = use(new ZstdDecompressCtx()) + forAll { input: Array[Byte] => + { + cctx.reset() + val compressedMagic = cctx.compress(input) + cctx.setMagicless(true) + val compressedMagicless = cctx.compress(input) + assert(compressedMagicless.length == (compressedMagic.length - 4)) + assert(input.length == Zstd.decompressedSize(compressedMagicless, 0, compressedMagicless.length, true)) + + dctx.reset() + dctx.setMagicless(true) + val decompressed = dctx.decompress(compressedMagicless, input.length) + assert(input.toSeq == decompressed.toSeq) + } + } + }.get + } }