Skip to content

Commit

Permalink
Added support for initialising a dict from a direct ByteBuffer. This …
Browse files Browse the repository at this point in the history
…avoid the critical region otherwise required on the byte array.
  • Loading branch information
Morten Grouleff authored and luben committed Apr 14, 2023
1 parent 67743d3 commit e845db1
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 5 deletions.
19 changes: 19 additions & 0 deletions src/main/java/com/github/luben/zstd/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,25 @@ public static long decompressedDirectByteBufferSize(ByteBuffer src, int srcPosit
*/
public static native long getDictIdFromDict(byte[] dict);

private static native long getDictIdFromDictDirect(ByteBuffer dict, int offset, int length);

/**
* Get DictId of a dictionary
*
* @param dict dictionary as Direct ByteBuffer
* @return DictId or 0 if not available
*/
public static long getDictIdFromDictDirect(ByteBuffer dict) {
int length = dict.limit() - dict.position();
if (!dict.isDirect()) {
throw new IllegalArgumentException("dict must be a direct buffer");
}
if (length < 0) {
throw new IllegalArgumentException("dict cannot be empty.");
}
return getDictIdFromDictDirect(dict, dict.position(), length);
}

/* Stub methods for backward comatibility
*/

Expand Down
29 changes: 29 additions & 0 deletions src/main/java/com/github/luben/zstd/ZstdDictCompress.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.github.luben.zstd;

import java.nio.ByteBuffer;
import com.github.luben.zstd.util.Native;

public class ZstdDictCompress extends SharedDictBase {
Expand All @@ -13,6 +14,8 @@ public class ZstdDictCompress extends SharedDictBase {

private native void init(byte[] dict, int dict_offset, int dict_size, int level);

private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level);

private native void free();

/**
Expand Down Expand Up @@ -49,6 +52,32 @@ public ZstdDictCompress(byte[] dict, int offset, int length, int level) {
storeFence();
}

/**
* Create a new dictionary for use with fast compress. The provided bytebuffer is available for reuse when the method returns.
*
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
* @param level compression level
*/
public ZstdDictCompress(ByteBuffer dict, int level) {
this.level = level;
int length = dict.limit() - dict.position();
if (!dict.isDirect()) {
throw new IllegalArgumentException("dict must be a direct buffer");
}
if (length < 0) {
throw new IllegalArgumentException("dict cannot be empty.");
}
initDirect(dict, dict.position(), length, level);

if (nativePtr == 0L) {
throw new IllegalStateException("ZSTD_createCDict failed");
}
// Ensures that even if ZstdDictCompress is created and published through a race, no thread could observe
// nativePtr == 0.
storeFence();
}


int level() {
return level;
}
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/com/github/luben/zstd/ZstdDictDecompress.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.github.luben.zstd;

import java.nio.ByteBuffer;
import com.github.luben.zstd.util.Native;

public class ZstdDictDecompress extends SharedDictBase {
Expand All @@ -12,6 +13,8 @@ public class ZstdDictDecompress extends SharedDictBase {

private native void init(byte[] dict, int dict_offset, int dict_size);

private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size);

private native void free();

/**
Expand Down Expand Up @@ -43,6 +46,31 @@ public ZstdDictDecompress(byte[] dict, int offset, int length) {
}


/**
* Create a new dictionary for use with fast decompress. The provided bytebuffer is available for reuse when the method returns.
*
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
*/
public ZstdDictDecompress(ByteBuffer dict) {

int length = dict.limit() - dict.position();
if (!dict.isDirect()) {
throw new IllegalArgumentException("dict must be a direct buffer");
}
if (length < 0) {
throw new IllegalArgumentException("dict cannot be empty.");
}
initDirect(dict, dict.position(), length);

if (nativePtr == 0L) {
throw new IllegalStateException("ZSTD_createDDict failed");
}
// Ensures that even if ZstdDictDecompress is created and published through a race, no thread could observe
// nativePtr == 0.
storeFence();
}


@Override
void doClose() {
if (nativePtr != 0) {
Expand Down
41 changes: 39 additions & 2 deletions src/main/native/jni_fast_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ static jfieldID decompress_dict = 0;
/*
* Class: com_github_luben_zstd_ZstdDictCompress
* Method: init
* Signature: ([BI)V
* Signature: ([BIII)V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_init
(JNIEnv *env, jobject obj, jbyteArray dict, jint dict_offset, jint dict_size, jint level)
Expand All @@ -29,6 +29,24 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_init
(*env)->SetLongField(env, obj, compress_dict, (jlong)(intptr_t) cdict);
}

/*
* Class: com_github_luben_zstd_ZstdDictCompress
* Method: init
* Signature: (Ljava/nio/ByteBuffer;III)V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_initDirect
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level)
{
jclass clazz = (*env)->GetObjectClass(env, obj);
compress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
if (NULL == dict) return;
void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);
if (NULL == dict_buff) return;
ZSTD_CDict* cdict = ZSTD_createCDict(((char *)dict_buff) + dict_offset, dict_size, level);
if (NULL == cdict) return;
(*env)->SetLongField(env, obj, compress_dict, (jlong)(intptr_t) cdict);
}

/*
* Class: com_github_luben_zstd_ZstdDictCompress
* Method: free
Expand All @@ -46,7 +64,7 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_free
/*
* Class: com_github_luben_zstd_ZstdDictDecompress
* Method: init
* Signature: ([B)V
* Signature: ([BII)V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_init
(JNIEnv *env, jobject obj, jbyteArray dict, jint dict_offset, jint dict_size)
Expand All @@ -64,6 +82,25 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_init
(*env)->SetLongField(env, obj, decompress_dict, (jlong)(intptr_t) ddict);
}

/*
* Class: com_github_luben_zstd_ZstdDictDecompress
* Method: initDirect
* Signature: (Ljava/nio/ByteBuffer;II)V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_initDirect
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size)
{
jclass clazz = (*env)->GetObjectClass(env, obj);
decompress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
if (NULL == dict) return;
void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);

ZSTD_DDict* ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size);

if (NULL == ddict) return;
(*env)->SetLongField(env, obj, decompress_dict, (jlong)(intptr_t) ddict);
}

/*
* Class: com_github_luben_zstd_ZstdDictDecompress
* Method: free
Expand Down
14 changes: 14 additions & 0 deletions src/main/native/jni_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getDictIdFromDict
E1: return (jlong) dict_id;
}

/*
* Class: com_github_luben_zstd_Zstd
* Method: getDictIdFromDict
* Signature: (Ljava/nio/ByteBuffer;II)J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getDictIdFromDictDirect
(JNIEnv *env, jclass obj, jobject src, jint offset, jint src_size) {
unsigned dict_id = 0;
char *src_buff = (char*)(*env)->GetDirectBufferAddress(env, src);
if (src_buff == NULL) goto E1;
dict_id = ZSTD_getDictID_fromDict(src_buff + offset, (size_t) src_size);
E1: return (jlong) dict_id;
}

/*
* Class: com_github_luben_zstd_Zstd
* Method: decompressedDirectByteBufferSize
Expand Down
18 changes: 15 additions & 3 deletions src/test/scala/ZstdDict.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ class ZstdDictSpec extends AnyFlatSpec {
dict
}

def wrapInDirectByteBuffer(array: Array[Byte]): ByteBuffer = {
//use a slightly oversized buffer and a nonzero offset to test we use the position/limit too
val bb = ByteBuffer.allocateDirect(array.length + 13);
bb.limit(bb.capacity())
bb.position(7)
bb.limit(bb.position() + array.length)
bb.put(array)
bb.flip()
bb
}

"Zstd" should "report error when failing to make a dict" in {
val src = source.sliding(28, 28).take(4).map(_.toArray)
val trainer = new ZstdDictTrainer(1024 * 1024, 32 * 1024)
Expand All @@ -42,6 +53,7 @@ class ZstdDictSpec extends AnyFlatSpec {
for {
legacy <- legacyS
dict = train(legacy)
dictInDirectByteBuffer = wrapInDirectByteBuffer(dict)
level <- levels
} {

Expand Down Expand Up @@ -99,19 +111,19 @@ class ZstdDictSpec extends AnyFlatSpec {
val inBuf = ByteBuffer.allocateDirect(size)
inBuf.put(input)
inBuf.flip()
val cdict = new ZstdDictCompress(dict, level)
val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level)
val compressed = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt);
Zstd.compress(compressed, inBuf, cdict)
compressed.flip()
cdict.close
val ddict = new ZstdDictDecompress(dict)
val ddict = new ZstdDictDecompress(dictInDirectByteBuffer)
val decompressed = ByteBuffer.allocateDirect(size)
Zstd.decompress(decompressed, compressed, ddict)
decompressed.flip()
ddict.close
val out = new Array[Byte](decompressed.remaining)
decompressed.get(out)
assert(Zstd.getDictIdFromFrameBuffer(compressed) == Zstd.getDictIdFromDict(dict))
assert(Zstd.getDictIdFromFrameBuffer(compressed) == Zstd.getDictIdFromDictDirect(dictInDirectByteBuffer))
assert(input.toSeq == out.toSeq)
}

Expand Down

0 comments on commit e845db1

Please sign in to comment.