Skip to content

Commit

Permalink
Merge pull request apache#20 from yanqingmen/scala
Browse files Browse the repository at this point in the history
IO handle modification (RefLong -> Long)
  • Loading branch information
yzhliu committed Jan 6, 2016
2 parents 722c7a8 + e0f1252 commit a0ad1e0
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 48 deletions.
6 changes: 4 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ object Base {
type SymbolHandle = CPtrAddress
type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress
type DataIterHandle = CPtrAddress
type DataIterCreator = CPtrAddress

type MXUintRef = RefInt
type MXFloatRef = RefFloat
type NDArrayHandleRef = RefLong
type FunctionHandleRef = RefLong
type DataIterHandle = RefLong
type DataIterCreator = RefLong
type DataIterHandleRef = RefLong
type DataIterCreatorRef = RefLong
type KVStoreHandle = RefLong
type ExecutorHandle = RefLong
type SymbolHandleRef = RefLong
Expand Down
4 changes: 2 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ object IO {
*/
private def creator(handle: DataIterCreator)(
params: Map[String, String]): DataIter = {
val out = new DataIterHandle
val out = new DataIterHandleRef
val keys = params.keys.toArray
val vals = params.values.toArray
checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out))
new MXDataIter(out)
new MXDataIter(out.value)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class LibInfo {
@native def mxDataIterCreateIter(handle: DataIterCreator,
keys: Array[String],
vals: Array[String],
out: DataIterHandle): Int
out: DataIterHandleRef): Int
@native def mxDataIterGetIterInfo(creator: DataIterCreator,
name: RefString,
description: RefString,
Expand Down
3 changes: 3 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
mnistIter.reset()
mnistIter.iterNext()
val label0 = mnistIter.getLabel().toArray
val data0 = mnistIter.getData().toArray
mnistIter.iterNext()
mnistIter.iterNext()
mnistIter.iterNext()
mnistIter.reset()
mnistIter.iterNext()
val label1 = mnistIter.getLabel().toArray
val data1 = mnistIter.getData().toArray
assert(label0 === label1)
assert(data0 === data1)
}


Expand Down
69 changes: 26 additions & 43 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,8 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env
//IO funcs
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters
(JNIEnv * env, jobject obj, jobject creators) {
// Base.FunctionHandle.constructor
jclass chClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jmethodID chConstructor = env->GetMethodID(chClass,"<init>","(J)V");
jclass longCls = env->FindClass("java/lang/Long");
jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");

// scala.collection.mutable.ListBuffer append method
jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
Expand All @@ -470,16 +469,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters
mx_uint outSize;
int ret = MXListDataIters(&outSize, &outArray);
for (int i = 0; i < outSize; ++i) {
DataIterCreator chAddr = outArray[i];
jobject chObj = env->NewObject(chClass, chConstructor, (long)chAddr);
env->CallObjectMethod(creators, listAppend, chObj);
env->CallObjectMethod(creators, listAppend,
env->NewObject(longCls, longConst, (long)outArray[i]));
}
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter
(JNIEnv * env, jobject obj, jobject creator,
jobjectArray jkeys, jobjectArray jvals, jobject dataIterHandle) {
(JNIEnv * env, jobject obj, jlong creator, jobjectArray jkeys,
jobjectArray jvals, jobject dataIterHandleRef) {
//keys and values
int paramSize = env->GetArrayLength(jkeys);
char** keys = new char*[paramSize];
Expand All @@ -501,16 +499,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter
}

//create iter
jlong creatorPtr = getLongField(env, creator);
DataIterHandle out;
int ret = MXDataIterCreateIter((DataIterCreator)creatorPtr,
int ret = MXDataIterCreateIter((DataIterCreator)creator,
(mx_uint) paramSize,
(const char**) keys,
(const char**) vals,
&out);
jclass hClass = env->GetObjectClass(dataIterHandle);
jfieldID ptr = env->GetFieldID(hClass, "value", "J");
env->SetLongField(dataIterHandle, ptr, (long)out);
setLongField(env, dataIterHandleRef, (long)out);

//release keys and vals
for(int i=0; i<paramSize; i++) {
Expand All @@ -524,16 +519,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIterInfo
(JNIEnv * env, jobject obj, jobject creator, jobject jname,
(JNIEnv * env, jobject obj, jlong creator, jobject jname,
jobject jdesc, jobject jargNames, jobject jargTypeInfos, jobject jargDescs) {
jlong creatorPtr = getLongField(env, creator);
const char* name;
const char* description;
mx_uint numArgs;
const char** argNames;
const char** argTypeInfos;
const char** argDescs;
int ret = MXDataIterGetIterInfo((DataIterCreator)creatorPtr,
int ret = MXDataIterGetIterInfo((DataIterCreator)creator,
&name,
&description,
&numArgs,
Expand All @@ -558,56 +552,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIterInfo
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterFree
(JNIEnv *env, jobject obj, jobject handle) {
jlong handlePtr = getLongField(env, handle);
int ret = MXDataIterFree((DataIterHandle) handlePtr);
(JNIEnv *env, jobject obj, jlong handle) {
int ret = MXDataIterFree((DataIterHandle) handle);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterBeforeFirst
(JNIEnv *env, jobject obj, jobject handle) {
jlong handlePtr = getLongField(env, handle);
int ret = MXDataIterBeforeFirst((DataIterHandle) handlePtr);
(JNIEnv *env, jobject obj, jlong handle) {
int ret = MXDataIterBeforeFirst((DataIterHandle) handle);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterNext
(JNIEnv *env, jobject obj, jobject handle, jobject out) {
jlong handlePtr = getLongField(env, handle);
(JNIEnv *env, jobject obj, jlong handle, jobject out) {
int cout;
int ret = MXDataIterNext((DataIterHandle)handlePtr, &cout);
int ret = MXDataIterNext((DataIterHandle)handle, &cout);
setIntField(env, out, cout);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetLabel
(JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) {
jlong handlePtr = getLongField(env, handle);
(JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
NDArrayHandle out;
int ret = MXDataIterGetLabel((DataIterHandle)handlePtr, &out);
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
env->SetLongField(ndArrayHandle, refLongFid, (jlong)out);
int ret = MXDataIterGetLabel((DataIterHandle)handle, &out);
setLongField(env, ndArrayHandleRef, (jlong)out);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetData
(JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) {
jlong handlePtr = getLongField(env, handle);
(JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
NDArrayHandle out;
int ret = MXDataIterGetData((DataIterHandle)handlePtr, &out);
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
env->SetLongField(ndArrayHandle, refLongFid, (jlong)out);
int ret = MXDataIterGetData((DataIterHandle)handle, &out);
setLongField(env, ndArrayHandleRef, (jlong)out);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex
(JNIEnv *env, jobject obj, jobject handle, jobject outIndex, jobject outSize) {
jlong handlePtr = getLongField(env, handle);
(JNIEnv *env, jobject obj, jlong handle, jobject outIndex, jobject outSize) {
uint64_t* coutIndex;
uint64_t coutSize;
int ret = MXDataIterGetIndex((DataIterHandle)handlePtr, &coutIndex, &coutSize);
int ret = MXDataIterGetIndex((DataIterHandle)handle, &coutIndex, &coutSize);
//set field
setLongField(env, outSize, (long)coutSize);
// scala.collection.mutable.ListBuffer append method
Expand All @@ -622,10 +606,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum
(JNIEnv *env, jobject obj, jobject handle, jobject pad) {
jlong handlePtr = getLongField(env, handle);
(JNIEnv *env, jobject obj, jlong handle, jobject pad) {
int cpad;
int ret = MXDataIterGetPadNum((DataIterHandle)handlePtr, &cpad);
int ret = MXDataIterGetPadNum((DataIterHandle)handle, &cpad);
setIntField(env, pad, cpad);
return ret;
}
Expand Down

0 comments on commit a0ad1e0

Please sign in to comment.