diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index addd83b19f87..3f23e1f31a42 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -14,13 +14,15 @@ object Base { type SymbolHandle = CPtrAddress type NDArrayHandle = CPtrAddress type FunctionHandle = CPtrAddress + type DataIterHandle = CPtrAddress + type DataIterCreator = CPtrAddress type MXUintRef = RefInt type MXFloatRef = RefFloat type NDArrayHandleRef = RefLong type FunctionHandleRef = RefLong - type DataIterHandle = RefLong - type DataIterCreator = RefLong + type DataIterHandleRef = RefLong + type DataIterCreatorRef = RefLong type KVStoreHandle = RefLong type ExecutorHandle = RefLong type SymbolHandleRef = RefLong diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 82c7d48abb47..3b6cac583f75 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -52,11 +52,11 @@ object IO { */ private def creator(handle: DataIterCreator)( params: Map[String, String]): DataIter = { - val out = new DataIterHandle + val out = new DataIterHandleRef val keys = params.keys.toArray val vals = params.values.toArray checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) - new MXDataIter(out) + new MXDataIter(out.value) } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index de3a45dfaad3..f2cc461b5ffa 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -80,7 +80,7 @@ class LibInfo { @native def mxDataIterCreateIter(handle: DataIterCreator, keys: Array[String], vals: Array[String], - out: DataIterHandle): Int + out: DataIterHandleRef): Int @native def mxDataIterGetIterInfo(creator: DataIterCreator, name: RefString, description: RefString, diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 2991743da008..fc215d67662a 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -35,13 +35,16 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { mnistIter.reset() mnistIter.iterNext() val label0 = mnistIter.getLabel().toArray + val data0 = mnistIter.getData().toArray mnistIter.iterNext() mnistIter.iterNext() mnistIter.iterNext() mnistIter.reset() mnistIter.iterNext() val label1 = mnistIter.getLabel().toArray + val data1 = mnistIter.getData().toArray assert(label0 === label1) + assert(data0 === data1) } diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 01e5e8c16691..50a83ce724a6 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -456,9 +456,8 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env //IO funcs JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters (JNIEnv * env, jobject obj, jobject creators) { - // Base.FunctionHandle.constructor - jclass chClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jmethodID chConstructor = env->GetMethodID(chClass,"","(J)V"); + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); // scala.collection.mutable.ListBuffer append method jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); @@ -470,16 +469,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters mx_uint outSize; int ret = MXListDataIters(&outSize, &outArray); for (int i = 0; i < outSize; ++i) { - DataIterCreator chAddr = outArray[i]; - jobject chObj = env->NewObject(chClass, chConstructor, (long)chAddr); - env->CallObjectMethod(creators, listAppend, chObj); + env->CallObjectMethod(creators, listAppend, + env->NewObject(longCls, longConst, (long)outArray[i])); } return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter - (JNIEnv * env, jobject obj, jobject creator, - jobjectArray jkeys, jobjectArray jvals, jobject dataIterHandle) { + (JNIEnv * env, jobject obj, jlong creator, jobjectArray jkeys, + jobjectArray jvals, jobject dataIterHandleRef) { //keys and values int paramSize = env->GetArrayLength(jkeys); char** keys = new char*[paramSize]; @@ -501,16 +499,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter } //create iter - jlong creatorPtr = getLongField(env, creator); DataIterHandle out; - int ret = MXDataIterCreateIter((DataIterCreator)creatorPtr, + int ret = MXDataIterCreateIter((DataIterCreator)creator, (mx_uint) paramSize, (const char**) keys, (const char**) vals, &out); - jclass hClass = env->GetObjectClass(dataIterHandle); - jfieldID ptr = env->GetFieldID(hClass, "value", "J"); - env->SetLongField(dataIterHandle, ptr, (long)out); + setLongField(env, dataIterHandleRef, (long)out); //release keys and vals for(int i=0; iFindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - env->SetLongField(ndArrayHandle, refLongFid, (jlong)out); + int ret = MXDataIterGetLabel((DataIterHandle)handle, &out); + setLongField(env, ndArrayHandleRef, (jlong)out); return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetData - (JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) { - jlong handlePtr = getLongField(env, handle); + (JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) { NDArrayHandle out; - int ret = MXDataIterGetData((DataIterHandle)handlePtr, &out); - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - env->SetLongField(ndArrayHandle, refLongFid, (jlong)out); + int ret = MXDataIterGetData((DataIterHandle)handle, &out); + setLongField(env, ndArrayHandleRef, (jlong)out); return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex - (JNIEnv *env, jobject obj, jobject handle, jobject outIndex, jobject outSize) { - jlong handlePtr = getLongField(env, handle); + (JNIEnv *env, jobject obj, jlong handle, jobject outIndex, jobject outSize) { uint64_t* coutIndex; uint64_t coutSize; - int ret = MXDataIterGetIndex((DataIterHandle)handlePtr, &coutIndex, &coutSize); + int ret = MXDataIterGetIndex((DataIterHandle)handle, &coutIndex, &coutSize); //set field setLongField(env, outSize, (long)coutSize); // scala.collection.mutable.ListBuffer append method @@ -622,10 +606,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum - (JNIEnv *env, jobject obj, jobject handle, jobject pad) { - jlong handlePtr = getLongField(env, handle); + (JNIEnv *env, jobject obj, jlong handle, jobject pad) { int cpad; - int ret = MXDataIterGetPadNum((DataIterHandle)handlePtr, &cpad); + int ret = MXDataIterGetPadNum((DataIterHandle)handle, &cpad); setIntField(env, pad, cpad); return ret; }