Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
yanqingmen committed Dec 29, 2015
2 parents 1f239ee + 19f445e commit 4710c7c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 33 deletions.
6 changes: 5 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,13 @@ abstract class DataIter (val batchSize: Int = 0) {
* DataIter built in MXNet.
* @param handle the handle to the underlying C++ Data Iterator
*/
class MXDataIter(var handle: DataIterHandle) extends DataIter {
class MXDataIter(val handle: DataIterHandle) extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

override def finalize() = {
checkCall(_LIB.mxDataIterFree(handle))
}

/**
* reset the iterator
*/
Expand Down
27 changes: 13 additions & 14 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,27 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
"seed" -> "10"
)

//println("create MNISTIter")
val mnist_iter = IO.createIterator("MNISTIter", params)
val mnistIter = IO.createIterator("MNISTIter", params)
//test_loop
mnist_iter.reset()
mnistIter.reset()
val nBatch = 600
var batchCount = 0
while(mnist_iter.iterNext()) {
val batch = mnist_iter.next()
while(mnistIter.iterNext()) {
val batch = mnistIter.next()
batchCount+=1
}
//test loop
assert(nBatch === batchCount)
//test reset
mnist_iter.reset()
mnist_iter.iterNext()
val label0 = mnist_iter.getLabel().toArray
mnist_iter.iterNext()
mnist_iter.iterNext()
mnist_iter.iterNext()
mnist_iter.reset()
mnist_iter.iterNext()
val label1 = mnist_iter.getLabel().toArray
mnistIter.reset()
mnistIter.iterNext()
val label0 = mnistIter.getLabel().toArray
mnistIter.iterNext()
mnistIter.iterNext()
mnistIter.iterNext()
mnistIter.reset()
mnistIter.iterNext()
val label1 = mnistIter.getLabel().toArray
assert(label0 === label1)
}

Expand Down
31 changes: 16 additions & 15 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,19 +522,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDateIterCreateIter
char** keys = new char*[paramSize];
char** vals = new char*[paramSize];
jstring jkey, jval;
//use strcpy and release char* created by JNI inplace
for(int i=0; i<paramSize; i++) {
jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
keys[i] = (char*)env->GetStringUTFChars(jkey, 0);
const char* ckey = env->GetStringUTFChars(jkey, 0);
keys[i] = new char[env->GetStringLength(jkey)];
strcpy(keys[i], ckey);
env->ReleaseStringUTFChars(jkey, ckey);

jval = (jstring) env->GetObjectArrayElement(jvals, i);
vals[i] = (char*)env->GetStringUTFChars(jval, 0);
const char* cval = env->GetStringUTFChars(jval, 0);
vals[i] = new char[env->GetStringLength(jval)];
strcpy(vals[i], cval);
env->ReleaseStringUTFChars(jval, cval);
}

// printf("paramSize: %d\n", paramSize);
// for(int i=0; i<paramSize; i++) {
// printf("key: %s\t",keys[i]);
// printf("value: %s\n", vals[i]);
// }

//create iter
jlong creatorPtr = getLongField(env, creator);
DataIterHandle out;
Expand All @@ -547,13 +549,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDateIterCreateIter
jfieldID ptr = env->GetFieldID(hClass, "value", "J");
env->SetLongField(dataIterHandle, ptr, (long)out);

//release const char*
//release keys and vals
for(int i=0; i<paramSize; i++) {
jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
env->ReleaseStringUTFChars(jkey,(const char*)keys[i]);
jval = (jstring) env->GetObjectArrayElement(jvals, i);
env->ReleaseStringUTFChars(jval,(const char*)vals[i]);
delete[] keys[i];
delete[] vals[i];
}
delete[] keys;
delete[] vals;

return ret;
}

Expand Down Expand Up @@ -649,9 +652,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex
jmethodID listAppend = env->GetMethodID(listClass,
"$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

// printf("outSize: %ld\n", coutSize);
for(int i=0; i<coutSize; i++) {
// printf("%ld\t", coutIndex[i]);
env->CallObjectMethod(outIndex, listAppend, (jlong)coutIndex[i]);
}
return ret;
Expand Down
6 changes: 3 additions & 3 deletions tests/travis/error_detector.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/bin/bash
file=scala_test_results.txt

testFail=$(grep -ci "ERROR" $file)
if [ "$testFail" != "0" ]; then
testFail=$(grep -ci "All tests passed" $file)
if [ "$testFail" == "0" ]; then
cat $file
echo "Some unit tests failed. "
exit 1
else
echo "All unit tests passed! "
fi
fi

0 comments on commit 4710c7c

Please sign in to comment.