Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/main/java/com/github/luben/zstd/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,15 @@ public static long decompress(byte[] dst, byte[] src) {
}
}

public static int decompress(byte[] dst, ByteBuffer srcBuf) {
ZstdDecompressCtx ctx = new ZstdDecompressCtx();
try {
return ctx.decompress(dst, srcBuf);
} finally {
ctx.close();
}
}

/**
* Decompresses buffer 'src' into buffer 'dst'.
*
Expand Down Expand Up @@ -1343,6 +1352,15 @@ public static int decompress(ByteBuffer dstBuf, ByteBuffer srcBuf) {
}
}

public static int decompress(ByteBuffer dstBuf, byte[] src) {
ZstdDecompressCtx ctx = new ZstdDecompressCtx();
try {
return ctx.decompress(dstBuf, src);
} finally {
ctx.close();
}
}

/**
* Decompress data
*
Expand Down
76 changes: 76 additions & 0 deletions src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,60 @@ public int decompressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[

private static native long decompressByteArray0(long nativePtr, byte[] dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize);

public int decompressByteArrayToDirectByteBuffer(ByteBuffer dstBuff, int dstOffset, int dstSize, byte[] srcBuff, int srcOffset, int srcSize) {
if (!dstBuff.isDirect()) {
throw new IllegalArgumentException("dstBuff must be a direct buffer");
}

Objects.checkFromIndexSize(srcOffset, srcSize, srcBuff.length);
Objects.checkFromIndexSize(dstOffset, dstSize, dstBuff.limit());

ensureOpen();
acquireSharedLock();

try {
long size = decompressByteArrayToDirectByteBuffer0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
if (Zstd.isError(size)) {
throw new ZstdException(size);
}
if (size > Integer.MAX_VALUE) {
throw new ZstdException(Zstd.errGeneric(), "Output size is greater than MAX_INT");
}
return (int) size;
} finally {
releaseSharedLock();
}
}

private static native long decompressByteArrayToDirectByteBuffer0(long nativePtr, ByteBuffer dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize);

public int decompressDirectByteBufferToByteArray(byte[] dstBuff, int dstOffset, int dstSize, ByteBuffer srcBuff, int srcOffset, int srcSize) {
if (!srcBuff.isDirect()) {
throw new IllegalArgumentException("srcBuff must be a direct buffer");
}

Objects.checkFromIndexSize(srcOffset, srcSize, srcBuff.limit());
Objects.checkFromIndexSize(dstOffset, dstSize, dstBuff.length);

ensureOpen();
acquireSharedLock();

try {
long size = decompressDirectByteBufferToByteArray0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
if (Zstd.isError(size)) {
throw new ZstdException(size);
}
if (size > Integer.MAX_VALUE) {
throw new ZstdException(Zstd.errGeneric(), "Output size is greater than MAX_INT");
}
return (int) size;
} finally {
releaseSharedLock();
}
}

private static native long decompressDirectByteBufferToByteArray0(long nativePtr, byte[] dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize);

/* Covenience methods */

/**
Expand Down Expand Up @@ -267,6 +321,28 @@ public int decompress(ByteBuffer dstBuf, ByteBuffer srcBuf) throws ZstdException
return size;
}

public int decompress(ByteBuffer dstBuf, byte[] src) throws ZstdException {
int size = decompressByteArrayToDirectByteBuffer(dstBuf, // decompress into dstBuf
dstBuf.position(), // write decompressed data at offset position()
dstBuf.limit() - dstBuf.position(), // write no more than limit() - position()
src, // read compressed data from src
0,
src.length);
dstBuf.position(dstBuf.position() + size);
return size;
}

public int decompress(byte[] dst, ByteBuffer srcBuf) throws ZstdException {
int size = decompressDirectByteBufferToByteArray(dst, // decompress into dst
0,
dst.length,
srcBuf, // read compressed data from srcBuf
srcBuf.position(), // read starting at offset position()
srcBuf.limit() - srcBuf.position()); // read no more than limit() - position()
srcBuf.position(srcBuf.limit());
return size;
}

public ByteBuffer decompress(ByteBuffer srcBuf, int originalSize) throws ZstdException {
ByteBuffer dstBuf = ByteBuffer.allocateDirect(originalSize);
int size = decompressDirectByteBuffer(dstBuf, 0, originalSize, srcBuf, srcBuf.position(), srcBuf.limit() - srcBuf.position());
Expand Down
66 changes: 66 additions & 0 deletions src/main/native/jni_fast_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -689,3 +689,69 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressB
E2: (*env)->ReleasePrimitiveArrayCritical(env, dst, dst_buff, 0);
E1: return size;
}

/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: decompressByteArrayToDirectByteBuffer0
* Signature: (Ljava/nio/ByteBuffer;II[BII)I
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressByteArrayToDirectByteBuffer0
(JNIEnv *env, jclass jclazz, jlong ptr, jobject dst, jint dst_offset, jint dst_size, jbyteArray src, jint src_offset, jint src_size) {
size_t size = -ZSTD_error_memory_allocation;

if (NULL == dst) return -ZSTD_error_dstSize_tooSmall;
if (NULL == src) return -ZSTD_error_srcSize_wrong;
if (0 > dst_offset) return -ZSTD_error_dstSize_tooSmall;
if (0 > src_offset) return -ZSTD_error_srcSize_wrong;
if (0 > src_size) return -ZSTD_error_srcSize_wrong;

if (src_offset + src_size > (*env)->GetArrayLength(env, src)) return -ZSTD_error_srcSize_wrong;
jsize dst_cap = (*env)->GetDirectBufferCapacity(env, dst);
if (dst_offset + dst_size > dst_cap) return -ZSTD_error_dstSize_tooSmall;

ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;

char *dst_buff = (char*)(*env)->GetDirectBufferAddress(env, dst);
if (dst_buff == NULL) return -ZSTD_error_memory_allocation;
void *src_buff = (*env)->GetPrimitiveArrayCritical(env, src, NULL);
if (src_buff == NULL) goto E1;

ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
size = ZSTD_decompressDCtx(dctx, ((char *)dst_buff) + dst_offset, (size_t) dst_size, ((char *)src_buff) + src_offset, (size_t) src_size);

(*env)->ReleasePrimitiveArrayCritical(env, src, src_buff, JNI_ABORT);
E1: return size;
}

/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: decompressDirectByteBufferToByteArray0
* Signature: ([BIILjava/nio/ByteBuffer;II)I
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressDirectByteBufferToByteArray0
(JNIEnv *env, jclass jclazz, jlong ptr, jbyteArray dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size) {
size_t size = -ZSTD_error_memory_allocation;

if (NULL == dst) return -ZSTD_error_dstSize_tooSmall;
if (NULL == src) return -ZSTD_error_srcSize_wrong;
if (0 > dst_offset) return -ZSTD_error_dstSize_tooSmall;
if (0 > src_offset) return -ZSTD_error_srcSize_wrong;
if (0 > src_size) return -ZSTD_error_srcSize_wrong;

if (dst_offset + dst_size > (*env)->GetArrayLength(env, dst)) return -ZSTD_error_dstSize_tooSmall;
jsize src_cap = (*env)->GetDirectBufferCapacity(env, src);
if (src_offset + src_size > src_cap) return -ZSTD_error_srcSize_wrong;

ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;

char *src_buff = (char*)(*env)->GetDirectBufferAddress(env, src);
if (src_buff == NULL) return -ZSTD_error_memory_allocation;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will leave the dst_buff behind critical lock that is never going to be released. I suggest adding E2 on the line before E1 and jump there.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another option is to swap the GetPrimitiveArrayCritical and GetDirectBufferAddress sections

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fixed, thank you.

void *dst_buff = (*env)->GetPrimitiveArrayCritical(env, dst, NULL);
if (dst_buff == NULL) goto E1;

ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
size = ZSTD_decompressDCtx(dctx, ((char *)dst_buff) + dst_offset, (size_t) dst_size, ((char *)src_buff) + src_offset, (size_t) src_size);

(*env)->ReleasePrimitiveArrayCritical(env, dst, dst_buff, 0);
E1: return size;
}
18 changes: 7 additions & 11 deletions src/test/scala/Zstd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,10 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
val size = input.length
val compressed = Zstd.compress(input, level)

val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size.toLong).toInt)
compressedBuffer.put(compressed)
compressedBuffer.limit(compressedBuffer.position())
compressedBuffer.flip()

val decompressedBuffer = Zstd.decompress(compressedBuffer, size)
val decompressed = new Array[Byte](size)
val decompressedBuffer = ByteBuffer.allocateDirect(size)
val decompressedSize = Zstd.decompress(decompressedBuffer, compressed);
val decompressed = new Array[Byte](decompressedSize)
decompressedBuffer.flip();
decompressedBuffer.get(decompressed)
input.toSeq == decompressed.toSeq
}
Expand All @@ -104,11 +101,10 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
val inputBuffer = ByteBuffer.allocateDirect(size)
inputBuffer.put(input)
inputBuffer.flip()
val compressedBuffer = Zstd.compress(inputBuffer, level)
val compressed = new Array[Byte](compressedBuffer.limit() - compressedBuffer.position())
compressedBuffer.get(compressed)
val compressedBuffer = Zstd.compress(inputBuffer, level)

val decompressed = Zstd.decompress(compressed, size)
val decompressed = new Array[Byte](size)
Zstd.decompress(decompressed, compressedBuffer)
input.toSeq == decompressed.toSeq
}
}
Expand Down
Loading