From 9430f20f0ee941206194581cd28fc2a3f4691637 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 5 Jan 2016 16:20:32 +0800 Subject: [PATCH] NDArray handle ptr change to pure long --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 6 +- .../main/scala/ml/dmlc/mxnet/Executor.scala | 2 +- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 8 +- .../main/scala/ml/dmlc/mxnet/KVStore.scala | 6 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 34 ++---- .../main/scala/ml/dmlc/mxnet/NDArray.scala | 28 ++--- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 104 ++++++------------ 7 files changed, 71 insertions(+), 117 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 26026fc4d2eb..addd83b19f87 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -12,11 +12,13 @@ object Base { type CPtrAddress = Long type SymbolHandle = CPtrAddress + type NDArrayHandle = CPtrAddress + type FunctionHandle = CPtrAddress type MXUintRef = RefInt type MXFloatRef = RefFloat - type NDArrayHandle = RefLong - type FunctionHandle = RefLong + type NDArrayHandleRef = RefLong + type FunctionHandleRef = RefLong type DataIterHandle = RefLong type DataIterCreator = RefLong type KVStoreHandle = RefLong diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index b7f47f8e07b7..9ee252366c1a 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -135,7 +135,7 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { */ def backward(outGrads: Array[NDArray]): Unit = { require(outGrads != null) - val ndArrayPtrs = outGrads.map(_.handle.value) + val ndArrayPtrs = outGrads.map(_.handle) checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs)) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 2a706703aecd..82c7d48abb47 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -157,9 +157,9 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { * @return the data of current batch */ override def getData(): NDArray = { - val out = new NDArrayHandle + val out = new NDArrayHandleRef checkCall(_LIB.mxDataIterGetData(handle, out)) - new NDArray(out, writable = false) + new NDArray(out.value, writable = false) } /** @@ -167,9 +167,9 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { * @return the label of current batch */ override def getLabel(): NDArray = { - val out = new NDArrayHandle + val out = new NDArrayHandleRef checkCall(_LIB.mxDataIterGetLabel(handle, out)) - new NDArray(out, writable = false) + new NDArray(out.value, writable = false) } /** diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index 3882383f70da..4b787a6bb2f6 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -37,7 +37,7 @@ class KVStore(private val handle: KVStoreHandle) { */ def init(keys: Array[Int], values: Array[NDArray]): Unit = { require(keys.length == values.length, "len(keys) != len(values)") - val valuePtrs = values.map(_.handle.value) + val valuePtrs = values.map(_.handle) checkCall(_LIB.mxKVStoreInit(handle, keys.length, keys, valuePtrs)) } @@ -61,7 +61,7 @@ class KVStore(private val handle: KVStoreHandle) { */ def push(keys: Array[Int], values: Array[NDArray], priority: Int): Unit = { require(keys.length == values.length, "len(keys) != len(values)") - val valuePtrs = values.map(_.handle.value) + val valuePtrs = values.map(_.handle) checkCall(_LIB.mxKVStorePush(handle, keys.length, keys, valuePtrs, priority)) } @@ -97,7 +97,7 @@ class KVStore(private val handle: KVStoreHandle) { */ def pull(keys: Array[Int], outs: Array[NDArray], priority: Int): Unit = { require(keys.length == outs.length, "len(keys) != len(outs)") - val outPtrs = outs.map(_.handle.value) + val outPtrs = outs.map(_.handle) checkCall(_LIB.mxKVStorePull(handle, keys.length, keys, outPtrs, priority)) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 0837e6d1ca80..de3a45dfaad3 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -11,13 +11,13 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} class LibInfo { @native def mxNDArrayFree(handle: NDArrayHandle): Int @native def mxGetLastError(): String - @native def mxNDArrayCreateNone(out: NDArrayHandle): Int + @native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int @native def mxNDArrayCreate(shape: Array[Int], ndim: Int, devType: Int, devId: Int, delayAlloc: Int, - out: NDArrayHandle): Int + out: NDArrayHandleRef): Int @native def mxNDArrayWaitAll(): Int @native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int @native def mxFuncDescribe(handle: FunctionHandle, @@ -33,13 +33,9 @@ class LibInfo { argTypes: ListBuffer[String], argDescs: ListBuffer[String]): Int @native def mxFuncInvoke(function: FunctionHandle, - // useVars ought to be Array[NDArrayHandle], - // we pass ptr address directly for performance consideration - useVars: Array[CPtrAddress], + useVars: Array[NDArrayHandle], scalarArgs: Array[MXFloat], - // mutateVars ought to be Array[NDArrayHandle], - // we pass ptr address directly for performance consideration - mutateVars: Array[CPtrAddress]): Int + mutateVars: Array[NDArrayHandle]): Int @native def mxNDArrayGetShape(handle: NDArrayHandle, ndim: MXUintRef, data: ArrayBuffer[Int]): Int @@ -49,7 +45,7 @@ class LibInfo { @native def mxNDArraySlice(handle: NDArrayHandle, start: MXUint, end: MXUint, - sliceHandle: NDArrayHandle): Int + sliceHandle: NDArrayHandleRef): Int @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle, source: Array[MXFloat], size: Int): Int @@ -57,22 +53,16 @@ class LibInfo { @native def mxKVStoreInit(handle: KVStoreHandle, len: MXUint, keys: Array[Int], - // values ought to be Array[NDArrayHandle], - // we pass ptr address directly for performance consideration - values: Array[CPtrAddress]): Int + values: Array[NDArrayHandle]): Int @native def mxKVStorePush(handle: KVStoreHandle, len: MXUint, keys: Array[Int], - // values ought to be Array[NDArrayHandle], - // we pass ptr address directly for performance consideration - values: Array[CPtrAddress], + values: Array[NDArrayHandle], priority: Int): Int @native def mxKVStorePull(handle: KVStoreHandle, len: MXUint, keys: Array[Int], - // outs ought to be Array[NDArrayHandle], - // we pass ptr address directly for performance consideration - outs: Array[CPtrAddress], + outs: Array[NDArrayHandle], priority: Int): Int @native def mxKVStoreSetUpdater(handle: KVStoreHandle, updaterFunc: MXKVStoreUpdater, @@ -101,9 +91,9 @@ class LibInfo { @native def mxDataIterBeforeFirst(handle: DataIterHandle): Int @native def mxDataIterNext(handle: DataIterHandle, out: RefInt): Int @native def mxDataIterGetLabel(handle: DataIterHandle, - out: NDArrayHandle): Int + out: NDArrayHandleRef): Int @native def mxDataIterGetData(handle: DataIterHandle, - out: NDArrayHandle): Int + out: NDArrayHandleRef): Int @native def mxDataIterGetIndex(handle: DataIterHandle, outIndex: ListBuffer[Long], outSize: RefLong): Int @@ -115,9 +105,7 @@ class LibInfo { @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int @native def mxExecutorBackward(handle: ExecutorHandle, gradsSize: Int, - // grads ought to be Array[NDArrayHandle], - // we pass ptr address directly for performance consideration - grads: Array[CPtrAddress]): Int + grads: Array[NDArrayHandle]): Int @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index cb09451a10da..83a45dded683 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -24,15 +24,15 @@ object NDArray { require(function != null, s"invalid function name $funcName") require(output == null || output.writable, "out must be writable") function match { - case BinaryNDArrayFunction(handle: NDArrayHandle, acceptEmptyMutate: Boolean) => + case BinaryNDArrayFunction(handle: FunctionHandle, acceptEmptyMutate: Boolean) => if (output == null) { require(acceptEmptyMutate, s"argument out is required to call $funcName") output = new NDArray(_newEmptyHandle()) } checkCall(_LIB.mxFuncInvoke(handle, - Array(lhs.handle.value, rhs.handle.value), + Array(lhs.handle, rhs.handle), Array[MXFloat](), - Array(output.handle.value))) + Array(output.handle))) case _ => throw new RuntimeException(s"call $funcName as binary function") } output @@ -53,9 +53,9 @@ object NDArray { output = new NDArray(_newEmptyHandle()) } checkCall(_LIB.mxFuncInvoke(handle, - Array(src.handle.value), + Array(src.handle), Array[MXFloat](), - Array(output.handle.value))) + Array(output.handle))) case _ => throw new RuntimeException(s"call $funcName as unary function") } output @@ -88,9 +88,9 @@ object NDArray { mutateVars = Array.fill[NDArray](nMutateVars)(new NDArray(_newEmptyHandle())) } checkCall(_LIB.mxFuncInvoke(handle, - useVarsRange.map(args(_).asInstanceOf[NDArray].handle.value).toArray, + useVarsRange.map(args(_).asInstanceOf[NDArray].handle).toArray, scalarRange.map(args(_).asInstanceOf[MXFloat]).toArray, - mutateVars.map(_.handle.value).array)) + mutateVars.map(_.handle).array)) case _ => throw new RuntimeException(s"call $funcName as generic function") } mutateVars @@ -103,9 +103,9 @@ object NDArray { * @return a new empty ndarray handle */ private def _newEmptyHandle(): NDArrayHandle = { - val hdl: NDArrayHandle = new NDArrayHandle + val hdl = new NDArrayHandleRef checkCall(_LIB.mxNDArrayCreateNone(hdl)) - hdl + hdl.value } /** @@ -117,7 +117,7 @@ object NDArray { private def _newAllocHandle(shape: Array[Int], ctx: Context, delayAlloc: Boolean): NDArrayHandle = { - val hdl = new NDArrayHandle + val hdl = new NDArrayHandleRef checkCall(_LIB.mxNDArrayCreate( shape, shape.length, @@ -125,7 +125,7 @@ object NDArray { ctx.deviceId, if (delayAlloc) 1 else 0, hdl)) - hdl + hdl.value } /** @@ -332,7 +332,7 @@ object NDArray { * NDArray is basic ndarray/Tensor like data structure in mxnet. */ // scalastyle:off finalize -class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { +class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true) { override def finalize(): Unit = { checkCall(_LIB.mxNDArrayFree(handle)) } @@ -356,9 +356,9 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { * @return a sliced NDArray that shares memory with current one. */ def slice(start: Int, stop: Int): NDArray = { - val sliceHandle = new NDArrayHandle() + val sliceHandle = new NDArrayHandleRef checkCall(_LIB.mxNDArraySlice(handle, start, stop, sliceHandle)) - new NDArray(handle = sliceHandle, writable = this.writable) + new NDArray(handle = sliceHandle.value, writable = this.writable) } def slice(start: Int): NDArray = { diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index f6950d097aa6..01e5e8c16691 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -9,7 +9,8 @@ #include "jni_helper_func.h" #include "ml_dmlc_mxnet_native_c_api.h" // generated by javah -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreateNone(JNIEnv *env, jobject obj, jobject ndArrayHandle) { +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreateNone + (JNIEnv *env, jobject obj, jobject ndArrayHandle) { NDArrayHandle out; int ret = MXNDArrayCreateNone(&out); jclass ndClass = env->GetObjectClass(ndArrayHandle); @@ -39,10 +40,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv *env, return MXNDArrayWaitAll(); } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListFunctions(JNIEnv *env, jobject obj, jobject functions) { - // Base.FunctionHandle.constructor - jclass fhClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jmethodID fhConstructor = env->GetMethodID(fhClass,"","(J)V"); +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListFunctions + (JNIEnv *env, jobject obj, jobject functions) { + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); // scala.collection.mutable.ListBuffer append method jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); @@ -54,21 +55,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListFunctions(JNIEnv *env, j mx_uint outSize; int ret = MXListFunctions(&outSize, &outArray); for (int i = 0; i < outSize; ++i) { - FunctionHandle fhAddr = outArray[i]; - jobject fhObj = env->NewObject(fhClass, fhConstructor, (long)fhAddr); - env->CallObjectMethod(functions, listAppend, fhObj); + env->CallObjectMethod(functions, listAppend, + env->NewObject(longCls, longConst, outArray[i])); } return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncDescribe(JNIEnv *env, jobject obj, - jobject funcHandle, - jobject nUsedVars, - jobject nScalars, - jobject nMutateVars, - jobject typeMask) { - jlong funcPtr = getLongField(env, funcHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncDescribe + (JNIEnv *env, jobject obj, jlong funcPtr, jobject nUsedVars, + jobject nScalars, jobject nMutateVars, jobject typeMask) { mx_uint numUseVars; mx_uint numScalars; mx_uint numMutateVars; @@ -86,16 +81,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncDescribe(JNIEnv *env, jo return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncGetInfo(JNIEnv *env, jobject obj, - jobject funcHandle, - jobject name, - jobject desc, - jobject numArgs, - jobject argNames, - jobject argTypes, - jobject argDescs) { - jlong funcPtr = getLongField(env, funcHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncGetInfo + (JNIEnv *env, jobject obj, jlong funcPtr, jobject name, + jobject desc, jobject numArgs, jobject argNames, jobject argTypes, jobject argDescs) { const char *cName; const char *cDesc; mx_uint cNumArgs; @@ -129,13 +117,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncGetInfo(JNIEnv *env, job return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncInvoke(JNIEnv *env, jobject obj, - jobject funcHandle, - jlongArray useVars, - jfloatArray scalarArgs, - jlongArray mutateVars) { - jlong funcPtr = getLongField(env, funcHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncInvoke + (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray useVars, + jfloatArray scalarArgs, jlongArray mutateVars) { jlong *cUseVars = env->GetLongArrayElements(useVars, NULL); jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL); jlong *cMutateVars = env->GetLongArrayElements(mutateVars, NULL); @@ -149,12 +133,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncInvoke(JNIEnv *env, jobj return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv *env, jobject obj, - jobject ndArrayHandle, - jobject ndimRef, - jobject dataBuf) { - jlong ndArrayPtr = getLongField(env, ndArrayHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape + (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject ndimRef, jobject dataBuf) { mx_uint ndim; const mx_uint *pdata; int ret = MXNDArrayGetShape((NDArrayHandle)ndArrayPtr, &ndim, &pdata); @@ -179,35 +159,24 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv *env, return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyToCPU(JNIEnv *env, jobject obj, - jobject ndArrayHandle, - jfloatArray data, - jint size) { - jlong ndArrayPtr = getLongField(env, ndArrayHandle); - +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyToCPU + (JNIEnv *env, jobject obj, jlong ndArrayPtr, jfloatArray data, jint size) { jfloat *pdata = env->GetFloatArrayElements(data, NULL); int ret = MXNDArraySyncCopyToCPU((NDArrayHandle)ndArrayPtr, (mx_float *)pdata, size); env->ReleaseFloatArrayElements(data, pdata, 0); // copy back to java array automatically return ret; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySlice(JNIEnv *env, jobject obj, - jobject ndArrayHandle, - jint start, - jint end, - jobject slicedHandle) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid); +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySlice + (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint start, jint end, jobject slicedHandle) { NDArrayHandle out; int ret = MXNDArraySlice((NDArrayHandle)ndArrayPtr, start, end, &out); - env->SetLongField(slicedHandle, refLongFid, (jlong)out); + setLongField(env, slicedHandle, (long)out); return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU - (JNIEnv *env, jobject obj, jobject ndArrayHandle, jfloatArray sourceArr, jint arrSize) { - jlong arrayPtr = getLongField(env, ndArrayHandle); + (JNIEnv *env, jobject obj, jlong arrayPtr, jfloatArray sourceArr, jint arrSize) { jfloat *sourcePtr = env->GetFloatArrayElements(sourceArr, NULL); int ret = MXNDArraySyncCopyFromCPU((NDArrayHandle)arrayPtr, (const mx_float *)sourcePtr, arrSize); env->ReleaseFloatArrayElements(sourceArr, sourcePtr, 0); @@ -215,8 +184,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree - (JNIEnv * env, jobject obj, jobject ndArrayHandle) { - return MXNDArrayFree((NDArrayHandle) getLongField(env, ndArrayHandle)); + (JNIEnv * env, jobject obj, jlong ndArrayHandle) { + return MXNDArrayFree((NDArrayHandle) ndArrayHandle); } // The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter, @@ -241,20 +210,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *en "update", "(ILml/dmlc/mxnet/NDArray;Lml/dmlc/mxnet/NDArray;Ljava/lang/Object;)V"); // find java NDArray constructor - jclass ndPtrClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jmethodID ndPtrConstructor = env->GetMethodID(ndPtrClass, "","(J)V"); jclass ndObjClass = env->FindClass("ml/dmlc/mxnet/NDArray"); - jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "", "(Lml/dmlc/mxnet/Base$RefLong;Z)V"); + jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "", "(JZ)V"); mxnet::NDArray *recvCopy = new mxnet::NDArray(); *recvCopy = recv; - jobject jNdRecvCopyPtr = env->NewObject(ndPtrClass, ndPtrConstructor, (long)recvCopy); - jobject jNdRecvCopy = env->NewObject(ndObjClass, ndObjConstructor, jNdRecvCopyPtr, true); + jobject jNdRecvCopy = env->NewObject(ndObjClass, ndObjConstructor, (jlong)recvCopy, true); mxnet::NDArray *localCopy = new mxnet::NDArray(); *localCopy = *local; - jobject jNdLocalCopyPtr = env->NewObject(ndPtrClass, ndPtrConstructor, (long)localCopy); - jobject jNdLocalCopy = env->NewObject(ndObjClass, ndObjConstructor, jNdLocalCopyPtr, true); + jobject jNdLocalCopy = env->NewObject(ndObjClass, ndObjConstructor, (jlong)localCopy, true); env->CallVoidMethod(updaterFuncObjGlb, updtFunc, key, jNdRecvCopy, jNdLocalCopy, updaterHandleGlb); env->DeleteGlobalRef(updaterFuncObjGlb); @@ -400,17 +365,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs NDArrayHandle *out; int ret = MXExecutorOutputs((ExecutorHandle)executorPtr, &outSize, &out); - // Base.ExecutorHandle.constructor - jclass ndArrayClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jmethodID ndArrayConstructor = env->GetMethodID(ndArrayClass,"","(J)V"); + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); // fill java outputs jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); jmethodID arrayAppend = env->GetMethodID(arrayClass, "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); for (int i = 0; i < outSize; ++i) { - jobject ndArray = env->NewObject(ndArrayClass, ndArrayConstructor, (long)out[i]); - env->CallObjectMethod(outputs, arrayAppend, ndArray); + env->CallObjectMethod(outputs, arrayAppend, + env->NewObject(longCls, longConst, out[i])); } return ret;