From ebba851d957d0cd853981b740816c447fef6985b Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 19 Dec 2015 23:54:58 +0800 Subject: [PATCH 1/2] KVStore setOptimizer, with JavaSerializer --- .../main/scala/ml/dmlc/mxnet/KVStore.scala | 41 +++++++++++++ .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 4 ++ .../main/scala/ml/dmlc/mxnet/Optimizer.scala | 27 ++++++++- .../main/scala/ml/dmlc/mxnet/Serializer.scala | 58 ++++++++++++++++++ .../native/src/main/native/jni_helper_func.h | 18 ++++++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 59 ++++++++++--------- 6 files changed, 176 insertions(+), 31 deletions(-) create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala create mode 100644 scala-package/native/src/main/native/jni_helper_func.h diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index f588aa20532b..cdb62e6d18a1 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -98,6 +98,32 @@ class KVStore(private val handle: KVStoreHandle) { pull(Array(key), Array(out), priority) } + // Get the type of this kvstore + def getType: String = { + val kvType = new RefString + checkCall(_LIB.mxKVStoreGetType(handle, kvType)) + kvType.value + } + + /** + * Register an optimizer to the store + * If there are multiple machines, this process (should be a worker node) + * will pack this optimizer and send it to all servers. It returns after + * this action is done. + * + * @param optimizer the optimizer + */ + def setOptimizer(optimizer: Optimizer): Unit = { + val isWorker = new RefInt + checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker)) + if ("dist" == getType && isWorker.value != 0) { + val optSerialized = Serializer.getSerializer.serialize(optimizer) + _sendCommandToServers(0, Serializer.encodeBase64String(optSerialized)) + } else { + setUpdater(Optimizer.getUpdater(optimizer)) + } + } + /** * Set a push updater into the store. * @@ -110,4 +136,19 @@ class KVStore(private val handle: KVStoreHandle) { this.updaterFunc = updater checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null)) } + + /** + * Send a command to all server nodes + * + * Send a command to all server nodes, which will make each server node run + * KVStoreServer.controller + * + * This function returns after the command has been executed in all server nodes + * + * @param head the head of the command + * @param body the body of the command + */ + private def _sendCommandToServers(head: Int, body: String): Unit = { + checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body)) + } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 18810b5440a5..c32877c571e2 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -71,4 +71,8 @@ class LibInfo { @native def mxKVStoreSetUpdater(handle: KVStoreHandle, updaterFunc: MXKVStoreUpdater, updaterHandle: AnyRef): Int + @native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int + @native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int + @native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle, + head: Int, body: String): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala index 7e23639bdad1..14d5dbd1e117 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala @@ -1,7 +1,30 @@ package ml.dmlc.mxnet -class Optimizer { +object Optimizer { + def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = { + new MXKVStoreUpdater { + val states = new scala.collection.mutable.HashMap[Int, AnyRef] + override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = { + val state = states.getOrElseUpdate(index, optimizer.createState(index, weight)) + optimizer.update(index, weight, grad, state) + } + } + } +} + +abstract class Optimizer extends Serializable { + /** + * Update the parameters. + * @param index An unique integer key used to index the parameters + * @param weight weight ndarray + * @param grad grad ndarray + * @param state NDArray or other objects returned by initState + * The auxiliary state used in optimization. + */ + def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = ??? + // Create additional optimizer state such as momentum. + def createState(index: Int, weight: NDArray): AnyRef } trait MXKVStoreUpdater { @@ -13,5 +36,5 @@ trait MXKVStoreUpdater { * @param local the value stored on local on this key * @param handle The additional handle to the updater */ - def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef): Unit + def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef = null): Unit } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala new file mode 100644 index 000000000000..21d2410911a6 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala @@ -0,0 +1,58 @@ +package ml.dmlc.mxnet + +import java.io._ +import java.nio.ByteBuffer +import java.nio.charset.Charset +import java.util.Base64 + +import scala.reflect.ClassTag + +/** + * Serialize & deserialize Java/Scala [[Serializable]] objects + * @author Yizhi Liu + */ +abstract class Serializer { + def serialize[T: ClassTag](t: T): ByteBuffer + def deserialize[T: ClassTag](bytes: ByteBuffer): T +} + +object Serializer { + val UTF8 = Charset.forName("UTF-8") + + def getSerializer: Serializer = getSerializer(None) + + def getSerializer(serializer: Serializer): Serializer = { + // TODO: dynamically get from mxnet env to support other serializers like Kyro + if (serializer == null) new JavaSerializer else serializer + } + + def getSerializer(serializer: Option[Serializer]): Serializer = { + // TODO: dynamically get from mxnet env to support other serializers like Kyro + serializer.getOrElse(new JavaSerializer) + } + + def encodeBase64String(bytes: ByteBuffer): String = { + new String(Base64.getEncoder.encode(bytes).array, UTF8) + } + + def decodeBase64String(str: String): ByteBuffer = { + ByteBuffer.wrap(Base64.getDecoder.decode(str.getBytes(UTF8))) + } +} + +class JavaSerializer extends Serializer { + override def serialize[T: ClassTag](t: T): ByteBuffer = { + val bos = new ByteArrayOutputStream() + val out = new ObjectOutputStream(bos) + out.writeObject(t) + out.close() + ByteBuffer.wrap(bos.toByteArray) + } + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + val byteArray = bytes.array() + val bis = new ByteArrayInputStream(byteArray) + val in = new ObjectInputStream(bis) + in.readObject().asInstanceOf[T] + } +} diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h new file mode 100644 index 000000000000..864733d2be60 --- /dev/null +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -0,0 +1,18 @@ +#include + +#ifndef MXNET_SCALA_JNI_HELPER_FUNC_H +#define MXNET_SCALA_JNI_HELPER_FUNC_H + +jlong getLongField(JNIEnv *env, jobject obj) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refFid = env->GetFieldID(refClass, "value", "J"); + return env->GetLongField(obj, refFid); +} + +jint getIntField(JNIEnv *env, jobject obj) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); + jfieldID refFid = env->GetFieldID(refClass, "value", "I"); + return env->GetIntField(obj, refFid); +} + +#endif diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index f711dff91dcd..3883ab84a562 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -1,11 +1,11 @@ #include #include #include -#include #include #include #include +#include "jni_helper_func.h" #include "ml_dmlc_mxnet_native_c_api.h" // generated by javah JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreateNone(JNIEnv *env, jobject obj, jobject ndArrayHandle) { @@ -66,9 +66,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncDescribe(JNIEnv *env, jo jobject nScalars, jobject nMutateVars, jobject typeMask) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong funcPtr = env->GetLongField(funcHandle, refLongFid); + jlong funcPtr = getLongField(env, funcHandle); mx_uint numUseVars; mx_uint numScalars; @@ -95,9 +93,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncGetInfo(JNIEnv *env, job jobject argNames, jobject argTypes, jobject argDescs) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong funcPtr = env->GetLongField(funcHandle, refLongFid); + jlong funcPtr = getLongField(env, funcHandle); const char *cName; const char *cDesc; @@ -137,9 +133,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncInvoke(JNIEnv *env, jobj jlongArray useVars, jfloatArray scalarArgs, jlongArray mutateVars) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong funcPtr = env->GetLongField(funcHandle, refLongFid); + jlong funcPtr = getLongField(env, funcHandle); jlong *cUseVars = env->GetLongArrayElements(useVars, NULL); jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL); @@ -158,9 +152,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv *env, jobject ndArrayHandle, jobject ndimRef, jobject dataBuf) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid); + jlong ndArrayPtr = getLongField(env, ndArrayHandle); mx_uint ndim; const mx_uint *pdata; @@ -190,9 +182,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyToCPU(JNIEnv jobject ndArrayHandle, jfloatArray data, jint size) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid); + jlong ndArrayPtr = getLongField(env, ndArrayHandle); jfloat *pdata = env->GetFloatArrayElements(data, NULL); int ret = MXNDArraySyncCopyToCPU((NDArrayHandle)ndArrayPtr, (mx_float *)pdata, size); @@ -224,9 +214,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *en jobject updaterFuncObj, jobject updaterHandle) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj); jobject updaterHandleGlb = env->NewGlobalRef(updaterHandle); @@ -290,9 +278,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit(JNIEnv *env, job jintArray keys, jlongArray values) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *valueArray = env->GetLongArrayElements(values, NULL); @@ -312,9 +298,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush(JNIEnv *env, job jlongArray values, jint priority) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *valueArray = env->GetLongArrayElements(values, NULL); @@ -335,9 +319,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job jlongArray outs, jint priority) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *outArray = env->GetLongArrayElements(outs, NULL); @@ -351,7 +333,26 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job return ret; } -// TODO: move to c_api_error.c +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType + (JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject kvType) { + jlong kvStorePtr = getLongField(env, kvStoreHandle); + const char *type; + int ret = MXKVStoreGetType((KVStoreHandle)kvStorePtr, &type); + jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); + jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;"); + env->SetObjectField(kvType, valueStr, env->NewStringUTF(type)); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers + (JNIEnv *env, jobject obj, jobject kvStoreHandle, jint head, jstring body) { + jlong kvStorePtr = getLongField(env, kvStoreHandle); + const char *bodyCStr = env->GetStringUTFChars(body, 0); + int ret = MXKVStoreSendCommmandToServers((KVStoreHandle)kvStorePtr, head, bodyCStr); + env->ReleaseStringUTFChars(body, bodyCStr); + return ret; +} + JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { char *tmpstr = "MXNetError"; jstring rtstr = env->NewStringUTF(tmpstr); From 96cd86543f7e023189e872ee0728f1bfe2ecfdca Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 20 Dec 2015 11:29:42 +0800 Subject: [PATCH 2/2] add KVStore rank & numWorkers --- scala-package/core/pom.xml | 4 +++ .../main/scala/ml/dmlc/mxnet/KVStore.scala | 36 +++++++++++++++++-- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 3 ++ .../main/scala/ml/dmlc/mxnet/Optimizer.scala | 2 +- .../main/scala/ml/dmlc/mxnet/Serializer.scala | 7 ++-- .../scala/ml/dmlc/mxnet/KVStoreSuite.scala | 11 ++++++ .../native/src/main/native/jni_helper_func.h | 6 ++++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 24 +++++++++++++ scala-package/pom.xml | 5 +++ 9 files changed, 92 insertions(+), 6 deletions(-) diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 9f97aa27f99d..d0cc9082c5dd 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -61,6 +61,10 @@ org.scala-lang scala-library + + commons-codec + commons-codec + log4j log4j diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index cdb62e6d18a1..536123c57c52 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -99,12 +99,32 @@ class KVStore(private val handle: KVStoreHandle) { } // Get the type of this kvstore - def getType: String = { + def `type`: String = { val kvType = new RefString checkCall(_LIB.mxKVStoreGetType(handle, kvType)) kvType.value } + /** + * Get the number of worker nodes + * @return The number of worker nodes + */ + def numWorkers: Int = { + val size = new RefInt + checkCall(_LIB.mxKVStoreGetGroupSize(handle, size)) + size.value + } + + /** + * Get the rank of this worker node + * @return The rank of this node, which is in [0, get_num_workers()) + */ + def rank: Int = { + val rank = new RefInt + checkCall(_LIB.mxKVStoreGetRank(handle, rank)) + rank.value + } + /** * Register an optimizer to the store * If there are multiple machines, this process (should be a worker node) @@ -116,7 +136,7 @@ class KVStore(private val handle: KVStoreHandle) { def setOptimizer(optimizer: Optimizer): Unit = { val isWorker = new RefInt checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker)) - if ("dist" == getType && isWorker.value != 0) { + if ("dist" == `type` && isWorker.value != 0) { val optSerialized = Serializer.getSerializer.serialize(optimizer) _sendCommandToServers(0, Serializer.encodeBase64String(optSerialized)) } else { @@ -137,6 +157,18 @@ class KVStore(private val handle: KVStoreHandle) { checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null)) } + /** + * Global barrier among all worker nodes + * + * For example, assume there are n machines, we want to let machine 0 first + * init the values, and then pull the inited value to all machines. Before + * pulling, we can place a barrier to guarantee that the initialization is + * finished. + */ + def barrier() { + checkCall(_LIB.mxKVStoreBarrier(handle)) + } + /** * Send a command to all server nodes * diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index c32877c571e2..46dd88c062f6 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -75,4 +75,7 @@ class LibInfo { @native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int @native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle, head: Int, body: String): Int + @native def mxKVStoreBarrier(handle: KVStoreHandle): Int + @native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int + @native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala index 14d5dbd1e117..0d0cd38d6638 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala @@ -3,7 +3,7 @@ package ml.dmlc.mxnet object Optimizer { def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = { new MXKVStoreUpdater { - val states = new scala.collection.mutable.HashMap[Int, AnyRef] + private val states = new scala.collection.mutable.HashMap[Int, AnyRef] override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = { val state = states.getOrElseUpdate(index, optimizer.createState(index, weight)) optimizer.update(index, weight, grad, state) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala index 21d2410911a6..180f6354e010 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala @@ -3,7 +3,8 @@ package ml.dmlc.mxnet import java.io._ import java.nio.ByteBuffer import java.nio.charset.Charset -import java.util.Base64 + +import org.apache.commons.codec.binary.Base64 import scala.reflect.ClassTag @@ -32,11 +33,11 @@ object Serializer { } def encodeBase64String(bytes: ByteBuffer): String = { - new String(Base64.getEncoder.encode(bytes).array, UTF8) + new String(Base64.encodeBase64(bytes.array), UTF8) } def decodeBase64String(str: String): ByteBuffer = { - ByteBuffer.wrap(Base64.getDecoder.decode(str.getBytes(UTF8))) + ByteBuffer.wrap(Base64.decodeBase64(str.getBytes(UTF8))) } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala index 36c5899ab612..b237500281c8 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala @@ -45,4 +45,15 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { kv.pull(3, ndArray) assert(ndArray.toArray === Array(6f, 6f)) } + + test("get type") { + val kv = KVStore.create("local") + assert(kv.`type` === "local") + } + + test("get numWorkers and rank") { + val kv = KVStore.create("local") + assert(kv.numWorkers === 1) + assert(kv.rank === 0) + } } diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h index 864733d2be60..0d651adf9a94 100644 --- a/scala-package/native/src/main/native/jni_helper_func.h +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -15,4 +15,10 @@ jint getIntField(JNIEnv *env, jobject obj) { return env->GetIntField(obj, refFid); } +void setIntField(JNIEnv *env, jobject obj, jint value) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); + jfieldID refFid = env->GetFieldID(refClass, "value", "I"); + env->SetIntField(obj, refFid, value); +} + #endif diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 3883ab84a562..4e6b1461c945 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -353,6 +353,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreBarrier + (JNIEnv *env, jobject obj, jobject kvStoreHandle) { + jlong kvStorePtr = getLongField(env, kvStoreHandle); + return MXKVStoreBarrier((KVStoreHandle)kvStorePtr); +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetGroupSize + (JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject sizeRef) { + jlong kvStorePtr = getLongField(env, kvStoreHandle); + int size; + int ret = MXKVStoreGetGroupSize((KVStoreHandle)kvStorePtr, &size); + setIntField(env, sizeRef, size); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetRank + (JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject rankRef) { + jlong kvStorePtr = getLongField(env, kvStoreHandle); + int rank; + int ret = MXKVStoreGetRank((KVStoreHandle)kvStorePtr, &rank); + setIntField(env, rankRef, rank); + return ret; +} + JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { char *tmpstr = "MXNetError"; jstring rtstr = env->NewStringUTF(tmpstr); diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 3ad34f315893..85206d11409b 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -179,6 +179,11 @@ scala-library ${scala.version} + + commons-codec + commons-codec + 1.10 + log4j log4j