diff --git a/src/main/java/com/github/luben/zstd/Zstd.java b/src/main/java/com/github/luben/zstd/Zstd.java index e79ca2a..2bd4d43 100644 --- a/src/main/java/com/github/luben/zstd/Zstd.java +++ b/src/main/java/com/github/luben/zstd/Zstd.java @@ -418,6 +418,15 @@ public static long decompress(byte[] dst, byte[] src) { } } + public static int decompress(byte[] dst, ByteBuffer srcBuf) { + ZstdDecompressCtx ctx = new ZstdDecompressCtx(); + try { + return ctx.decompress(dst, srcBuf); + } finally { + ctx.close(); + } + } + /** * Decompresses buffer 'src' into buffer 'dst'. * @@ -1343,6 +1352,15 @@ public static int decompress(ByteBuffer dstBuf, ByteBuffer srcBuf) { } } + public static int decompress(ByteBuffer dstBuf, byte[] src) { + ZstdDecompressCtx ctx = new ZstdDecompressCtx(); + try { + return ctx.decompress(dstBuf, src); + } finally { + ctx.close(); + } + } + /** * Decompress data * diff --git a/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java b/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java index 5cbe5a2..8912eae 100644 --- a/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java +++ b/src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java @@ -234,6 +234,60 @@ public int decompressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[ private static native long decompressByteArray0(long nativePtr, byte[] dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize); + public int decompressByteArrayToDirectByteBuffer(ByteBuffer dstBuff, int dstOffset, int dstSize, byte[] srcBuff, int srcOffset, int srcSize) { + if (!dstBuff.isDirect()) { + throw new IllegalArgumentException("dstBuff must be a direct buffer"); + } + + Objects.checkFromIndexSize(srcOffset, srcSize, srcBuff.length); + Objects.checkFromIndexSize(dstOffset, dstSize, dstBuff.limit()); + + ensureOpen(); + acquireSharedLock(); + + try { + long size = decompressByteArrayToDirectByteBuffer0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize); + if (Zstd.isError(size)) { + throw new ZstdException(size); + } + if (size > Integer.MAX_VALUE) { + throw new ZstdException(Zstd.errGeneric(), "Output size is greater than MAX_INT"); + } + return (int) size; + } finally { + releaseSharedLock(); + } + } + + private static native long decompressByteArrayToDirectByteBuffer0(long nativePtr, ByteBuffer dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize); + + public int decompressDirectByteBufferToByteArray(byte[] dstBuff, int dstOffset, int dstSize, ByteBuffer srcBuff, int srcOffset, int srcSize) { + if (!srcBuff.isDirect()) { + throw new IllegalArgumentException("srcBuff must be a direct buffer"); + } + + Objects.checkFromIndexSize(srcOffset, srcSize, srcBuff.limit()); + Objects.checkFromIndexSize(dstOffset, dstSize, dstBuff.length); + + ensureOpen(); + acquireSharedLock(); + + try { + long size = decompressDirectByteBufferToByteArray0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize); + if (Zstd.isError(size)) { + throw new ZstdException(size); + } + if (size > Integer.MAX_VALUE) { + throw new ZstdException(Zstd.errGeneric(), "Output size is greater than MAX_INT"); + } + return (int) size; + } finally { + releaseSharedLock(); + } + } + + private static native long decompressDirectByteBufferToByteArray0(long nativePtr, byte[] dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize); + /* Covenience methods */ /** @@ -267,6 +321,28 @@ public int decompress(ByteBuffer dstBuf, ByteBuffer srcBuf) throws ZstdException return size; } + public int decompress(ByteBuffer dstBuf, byte[] src) throws ZstdException { + int size = decompressByteArrayToDirectByteBuffer(dstBuf, // decompress into dstBuf + dstBuf.position(), // write decompressed data at offset position() + dstBuf.limit() - dstBuf.position(), // write no more than limit() - position() + src, // read compressed data from src + 0, + src.length); + dstBuf.position(dstBuf.position() + size); + return size; + } + + public int decompress(byte[] dst, ByteBuffer srcBuf) throws ZstdException { + int size = decompressDirectByteBufferToByteArray(dst, // decompress into dst + 0, + dst.length, + srcBuf, // read compressed data from srcBuf + srcBuf.position(), // read starting at offset position() + srcBuf.limit() - srcBuf.position()); // read no more than limit() - position() + srcBuf.position(srcBuf.limit()); + return size; + } + public ByteBuffer decompress(ByteBuffer srcBuf, int originalSize) throws ZstdException { ByteBuffer dstBuf = ByteBuffer.allocateDirect(originalSize); int size = decompressDirectByteBuffer(dstBuf, 0, originalSize, srcBuf, srcBuf.position(), srcBuf.limit() - srcBuf.position()); diff --git a/src/main/native/jni_fast_zstd.c b/src/main/native/jni_fast_zstd.c index c488277..4ac71a0 100644 --- a/src/main/native/jni_fast_zstd.c +++ b/src/main/native/jni_fast_zstd.c @@ -689,3 +689,69 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressB E2: (*env)->ReleasePrimitiveArrayCritical(env, dst, dst_buff, 0); E1: return size; } + +/* + * Class: com_github_luben_zstd_ZstdDecompressCtx + * Method: decompressByteArrayToDirectByteBuffer0 + * Signature: (Ljava/nio/ByteBuffer;II[BII)I + */ +JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressByteArrayToDirectByteBuffer0 + (JNIEnv *env, jclass jclazz, jlong ptr, jobject dst, jint dst_offset, jint dst_size, jbyteArray src, jint src_offset, jint src_size) { + size_t size = -ZSTD_error_memory_allocation; + + if (NULL == dst) return -ZSTD_error_dstSize_tooSmall; + if (NULL == src) return -ZSTD_error_srcSize_wrong; + if (0 > dst_offset) return -ZSTD_error_dstSize_tooSmall; + if (0 > src_offset) return -ZSTD_error_srcSize_wrong; + if (0 > src_size) return -ZSTD_error_srcSize_wrong; + + if (src_offset + src_size > (*env)->GetArrayLength(env, src)) return -ZSTD_error_srcSize_wrong; + jsize dst_cap = (*env)->GetDirectBufferCapacity(env, dst); + if (dst_offset + dst_size > dst_cap) return -ZSTD_error_dstSize_tooSmall; + + ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr; + + char *dst_buff = (char*)(*env)->GetDirectBufferAddress(env, dst); + if (dst_buff == NULL) return -ZSTD_error_memory_allocation; + void *src_buff = (*env)->GetPrimitiveArrayCritical(env, src, NULL); + if (src_buff == NULL) goto E1; + + ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only); + size = ZSTD_decompressDCtx(dctx, ((char *)dst_buff) + dst_offset, (size_t) dst_size, ((char *)src_buff) + src_offset, (size_t) src_size); + + (*env)->ReleasePrimitiveArrayCritical(env, src, src_buff, JNI_ABORT); +E1: return size; +} + +/* + * Class: com_github_luben_zstd_ZstdDecompressCtx + * Method: decompressDirectByteBufferToByteArray0 + * Signature: ([BIILjava/nio/ByteBuffer;II)I + */ +JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressDirectByteBufferToByteArray0 + (JNIEnv *env, jclass jclazz, jlong ptr, jbyteArray dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size) { + size_t size = -ZSTD_error_memory_allocation; + + if (NULL == dst) return -ZSTD_error_dstSize_tooSmall; + if (NULL == src) return -ZSTD_error_srcSize_wrong; + if (0 > dst_offset) return -ZSTD_error_dstSize_tooSmall; + if (0 > src_offset) return -ZSTD_error_srcSize_wrong; + if (0 > src_size) return -ZSTD_error_srcSize_wrong; + + if (dst_offset + dst_size > (*env)->GetArrayLength(env, dst)) return -ZSTD_error_dstSize_tooSmall; + jsize src_cap = (*env)->GetDirectBufferCapacity(env, src); + if (src_offset + src_size > src_cap) return -ZSTD_error_srcSize_wrong; + + ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr; + + char *src_buff = (char*)(*env)->GetDirectBufferAddress(env, src); + if (src_buff == NULL) return -ZSTD_error_memory_allocation; + void *dst_buff = (*env)->GetPrimitiveArrayCritical(env, dst, NULL); + if (dst_buff == NULL) goto E1; + + ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only); + size = ZSTD_decompressDCtx(dctx, ((char *)dst_buff) + dst_offset, (size_t) dst_size, ((char *)src_buff) + src_offset, (size_t) src_size); + + (*env)->ReleasePrimitiveArrayCritical(env, dst, dst_buff, 0); +E1: return size; +} diff --git a/src/test/scala/Zstd.scala b/src/test/scala/Zstd.scala index 31d68ea..080fcbf 100644 --- a/src/test/scala/Zstd.scala +++ b/src/test/scala/Zstd.scala @@ -86,13 +86,10 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { val size = input.length val compressed = Zstd.compress(input, level) - val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size.toLong).toInt) - compressedBuffer.put(compressed) - compressedBuffer.limit(compressedBuffer.position()) - compressedBuffer.flip() - - val decompressedBuffer = Zstd.decompress(compressedBuffer, size) - val decompressed = new Array[Byte](size) + val decompressedBuffer = ByteBuffer.allocateDirect(size) + val decompressedSize = Zstd.decompress(decompressedBuffer, compressed); + val decompressed = new Array[Byte](decompressedSize) + decompressedBuffer.flip(); decompressedBuffer.get(decompressed) input.toSeq == decompressed.toSeq } @@ -104,11 +101,10 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { val inputBuffer = ByteBuffer.allocateDirect(size) inputBuffer.put(input) inputBuffer.flip() - val compressedBuffer = Zstd.compress(inputBuffer, level) - val compressed = new Array[Byte](compressedBuffer.limit() - compressedBuffer.position()) - compressedBuffer.get(compressed) + val compressedBuffer = Zstd.compress(inputBuffer, level) - val decompressed = Zstd.decompress(compressed, size) + val decompressed = new Array[Byte](size) + Zstd.decompress(decompressed, compressedBuffer) input.toSeq == decompressed.toSeq } }