diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 3f23e1f31a42..9f02c6e4e2e0 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -11,11 +11,13 @@ object Base { type MXFloat = Float type CPtrAddress = Long - type SymbolHandle = CPtrAddress type NDArrayHandle = CPtrAddress type FunctionHandle = CPtrAddress type DataIterHandle = CPtrAddress type DataIterCreator = CPtrAddress + type KVStoreHandle = CPtrAddress + type ExecutorHandle = CPtrAddress + type SymbolHandle = CPtrAddress type MXUintRef = RefInt type MXFloatRef = RefFloat @@ -23,8 +25,8 @@ object Base { type FunctionHandleRef = RefLong type DataIterHandleRef = RefLong type DataIterCreatorRef = RefLong - type KVStoreHandle = RefLong - type ExecutorHandle = RefLong + type KVStoreHandleRef = RefLong + type ExecutorHandleRef = RefLong type SymbolHandleRef = RefLong System.loadLibrary("mxnet-scala") diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 9ee252366c1a..60513db13635 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -86,7 +86,7 @@ object Executor { * @see Symbol.bind : to create executor */ // scalastyle:off finalize -class Executor(val handle: ExecutorHandle, val symbol: Symbol) { +class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val symbol: Symbol) { var argArrays: Array[NDArray] = null protected var gradArrays: Array[NDArray] = null protected var auxArrays: Array[NDArray] = null @@ -136,7 +136,7 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { def backward(outGrads: Array[NDArray]): Unit = { require(outGrads != null) val ndArrayPtrs = outGrads.map(_.handle) - checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs)) + checkCall(_LIB.mxExecutorBackward(handle, ndArrayPtrs)) } def backward(outGrad: NDArray): Unit = { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index 4b787a6bb2f6..40b62b652990 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -17,9 +17,9 @@ object KVStore { * @return The created KVStore */ def create(name: String = "local"): KVStore = { - val handle = new KVStoreHandle + val handle = new KVStoreHandleRef checkCall(_LIB.mxKVStoreCreate(name, handle)) - new KVStore(handle) + new KVStore(handle.value) } } @@ -156,7 +156,7 @@ class KVStore(private val handle: KVStoreHandle) { checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker)) if ("dist" == `type` && isWorker.value != 0) { val optSerialized = Serializer.getSerializer.serialize(optimizer) - _sendCommandToServers(0, Serializer.encodeBase64String(optSerialized)) + sendCommandToServers(0, Serializer.encodeBase64String(optSerialized)) } else { setUpdater(Optimizer.getUpdater(optimizer)) } @@ -198,7 +198,7 @@ class KVStore(private val handle: KVStoreHandle) { * @param head the head of the command * @param body the body of the command */ - private def _sendCommandToServers(head: Int, body: String): Unit = { + private def sendCommandToServers(head: Int, body: String): Unit = { checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body)) } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index f2cc461b5ffa..8ca87dc2873f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -49,7 +49,7 @@ class LibInfo { @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle, source: Array[MXFloat], size: Int): Int - @native def mxKVStoreCreate(name: String, handle: KVStoreHandle): Int + @native def mxKVStoreCreate(name: String, handle: KVStoreHandleRef): Int @native def mxKVStoreInit(handle: KVStoreHandle, len: MXUint, keys: Array[Int], @@ -104,7 +104,6 @@ class LibInfo { @native def mxExecutorFree(handle: ExecutorHandle): Int @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int @native def mxExecutorBackward(handle: ExecutorHandle, - gradsSize: Int, grads: Array[NDArrayHandle]): Int @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 50a83ce724a6..c844ef36b950 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -193,13 +193,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree // Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked, // which implies the function registered to KVStore must be stateful. // This is why we re-implement MXKVStoreSetUpdater as follows. -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *env, jobject obj, - jobject kvStoreHandle, - jobject updaterFuncObj, - jobject updaterHandle) { - // get kv store ptr - jlong kvStorePtr = getLongField(env, kvStoreHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater + (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj, jobject updaterHandle) { jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj); jobject updaterHandleGlb = env->NewGlobalRef(updaterHandle); std::function updt @@ -237,9 +232,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *en return 0; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate(JNIEnv *env, jobject obj, - jstring name, - jobject kvStoreHandle) { +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate + (JNIEnv *env, jobject obj, jstring name, jobject kvStoreHandle) { jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); @@ -252,14 +246,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate(JNIEnv *env, j return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit(JNIEnv *env, jobject obj, - jobject kvStoreHandle, - jint len, - jintArray keys, - jlongArray values) { - // get kv store ptr - jlong kvStorePtr = getLongField(env, kvStoreHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit + (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys, jlongArray values) { jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *valueArray = env->GetLongArrayElements(values, NULL); int ret = MXKVStoreInit((KVStoreHandle) kvStorePtr, @@ -271,15 +259,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit(JNIEnv *env, job return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush(JNIEnv *env, jobject obj, - jobject kvStoreHandle, - jint len, - jintArray keys, - jlongArray values, - jint priority) { - // get kv store ptr - jlong kvStorePtr = getLongField(env, kvStoreHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush + (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys, + jlongArray values, jint priority) { jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *valueArray = env->GetLongArrayElements(values, NULL); int ret = MXKVStorePush((KVStoreHandle)kvStorePtr, @@ -292,15 +274,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush(JNIEnv *env, job return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, jobject obj, - jobject kvStoreHandle, - jint len, - jintArray keys, - jlongArray outs, - jint priority) { - // get kv store ptr - jlong kvStorePtr = getLongField(env, kvStoreHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull + (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys, + jlongArray outs, jint priority) { jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *outArray = env->GetLongArrayElements(outs, NULL); int ret = MXKVStorePull((KVStoreHandle)kvStorePtr, @@ -314,8 +290,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType - (JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject kvType) { - jlong kvStorePtr = getLongField(env, kvStoreHandle); + (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject kvType) { const char *type; int ret = MXKVStoreGetType((KVStoreHandle)kvStorePtr, &type); jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); @@ -325,8 +300,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers - (JNIEnv *env, jobject obj, jobject kvStoreHandle, jint head, jstring body) { - jlong kvStorePtr = getLongField(env, kvStoreHandle); + (JNIEnv *env, jobject obj, jlong kvStorePtr, jint head, jstring body) { const char *bodyCStr = env->GetStringUTFChars(body, 0); int ret = MXKVStoreSendCommmandToServers((KVStoreHandle)kvStorePtr, head, bodyCStr); env->ReleaseStringUTFChars(body, bodyCStr); @@ -334,14 +308,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreBarrier - (JNIEnv *env, jobject obj, jobject kvStoreHandle) { - jlong kvStorePtr = getLongField(env, kvStoreHandle); + (JNIEnv *env, jobject obj, jlong kvStorePtr) { return MXKVStoreBarrier((KVStoreHandle)kvStorePtr); } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetGroupSize - (JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject sizeRef) { - jlong kvStorePtr = getLongField(env, kvStoreHandle); + (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject sizeRef) { int size; int ret = MXKVStoreGetGroupSize((KVStoreHandle)kvStorePtr, &size); setIntField(env, sizeRef, size); @@ -349,8 +321,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetGroupSize } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetRank - (JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject rankRef) { - jlong kvStorePtr = getLongField(env, kvStoreHandle); + (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject rankRef) { int rank; int ret = MXKVStoreGetRank((KVStoreHandle)kvStorePtr, &rank); setIntField(env, rankRef, rank); @@ -358,9 +329,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetRank } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs - (JNIEnv *env, jobject obj, jobject executorHandle, jobject outputs) { - - jlong executorPtr = getLongField(env, executorHandle); + (JNIEnv *env, jobject obj, jlong executorPtr, jobject outputs) { mx_uint outSize; NDArrayHandle *out; int ret = MXExecutorOutputs((ExecutorHandle)executorPtr, &outSize, &out); @@ -381,20 +350,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorFree - (JNIEnv * env, jobject obj, jobject handle) { - jlong ptr = getLongField(env, handle); + (JNIEnv * env, jobject obj, jlong ptr) { return MXExecutorFree((ExecutorHandle) ptr); } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorForward - (JNIEnv * env, jobject obj, jobject handle, jint isTrain) { - jlong ptr = getLongField(env, handle); + (JNIEnv * env, jobject obj, jlong ptr, jint isTrain) { return MXExecutorForward((ExecutorHandle)ptr, (int)isTrain); } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBackward - (JNIEnv * env, jobject obj, jobject handle, jint gradsSize, jlongArray grads) { - jlong executorPtr = getLongField(env, handle); + (JNIEnv * env, jobject obj, jlong executorPtr, jlongArray grads) { + int gradsSize = env->GetArrayLength(grads); jlong *gradArr = env->GetLongArrayElements(grads, NULL); int ret = MXExecutorBackward((ExecutorHandle)executorPtr, (mx_uint)gradsSize, @@ -404,17 +371,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBackward } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorPrint - (JNIEnv * env, jobject obj, jobject handle, jobject debugStr) { - jlong ptr = getLongField(env, handle); + (JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) { const char *retDebugStr; - int ret = MXExecutorPrint((ExecutorHandle)handle, &retDebugStr); + int ret = MXExecutorPrint((ExecutorHandle)ptr, &retDebugStr); setStringField(env, debugStr, retDebugStr); return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorSetMonitorCallback - (JNIEnv *env, jobject obj, jobject handle, jobject callbackFuncObj) { - jlong executorPtr = getLongField(env, handle); + (JNIEnv *env, jobject obj, jlong executorPtr, jobject callbackFuncObj) { jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj); std::function callback = [env, callbackFuncObjGlb](const char *name, NDArrayHandle array) {