Skip to content

Commit

Permalink
add IO MINISITer test
Browse files Browse the repository at this point in the history
  • Loading branch information
yanqingmen committed Dec 26, 2015
1 parent 949a611 commit 1bbde69
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
4 changes: 2 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import scala.collection.mutable.ListBuffer
object IO {
private val logger = LoggerFactory.getLogger(classOf[DataIter])
type IterCreateFunc = (Map[String, String])=>DataIter
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()
val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()

def _initIOModule(): Map[String, IterCreateFunc] = {
val IterCreators = new ListBuffer[DataIterCreator]
Expand All @@ -25,7 +25,7 @@ object IO {
checkCall(_LIB.mxDataIterGetIterInfo(handle, name, desc, argNames, argTypes, argDescs))
val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs)
val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n"
logger.debug(docStr)
// logger.debug(docStr)
return (name.value, creator(handle))
}

Expand Down
32 changes: 29 additions & 3 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,34 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}


class IOSuite extends FunSuite with BeforeAndAfterAll {
test("create iter funcs") {
val iterCreateFuncs: Map[String, IO.IterCreateFunc] = IO._initIOModule()
println(iterCreateFuncs.keys.toList)
test("test MNISTIter") {
val params = Map(
"image" -> "/home/hzx/workspace/git/mxnet-scala/mxnet/tests/python/common/data/train-images-idx3-ubyte",
"label" -> "/home/hzx/workspace/git/mxnet-scala/mxnet/tests/python/common/data/train-labels-idx1-ubyte",
"data_shape" -> "(784,)",
"batch_size" -> "100",
"shuffle" -> "1",
"flat" -> "1",
"silent" -> "0",
"seed" -> "10"
)
// println("create MNISTIter")
val mnist_iter = IO.iterCreateFuncs("MNISTIter")(params)
mnist_iter.reset()
mnist_iter.iterNext()
while(mnist_iter.iterNext()) {
val data = mnist_iter.getData()
val label = mnist_iter.getLabel()
val index = mnist_iter.getIndex()
val pad = mnist_iter.getPad()
// println("data: " + data.toArray.mkString(","))
// println("label: " + label.toArray.mkString(","))
// println("index: " + index)
// println("pad: " + pad)
}
}

test("test ImageRecordIter") {

}
}
5 changes: 3 additions & 2 deletions scala-package/native/linux-x86_64/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@
-fopenmp
</linkerMiddleOption>
<linkerMiddleOption>-Wl,--whole-archive</linkerMiddleOption>
<linkerMiddleOption>../../../lib/libmxnet.a</linkerMiddleOption>
<linkerMiddleOption>-Wl,-no-whole-archive</linkerMiddleOption>
<linkerMiddleOption>../../../lib/libmxnet.a ../../../dmlc-core/libdmlc.a
</linkerMiddleOption>
<linkerMiddleOption>-Wl,--no-whole-archive</linkerMiddleOption>
</linkerMiddleOptions>
</configuration>

Expand Down
23 changes: 20 additions & 3 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDateIterCreateIter
vals[i] = (char*)env->GetStringUTFChars(jval, 0);
}

// printf("paramSize: %d\n", paramSize);
// for(int i=0; i<paramSize; i++) {
// printf("key: %s\t",keys[i]);
// printf("value: %s\n", vals[i]);
// }

//create iter
jlong creatorPtr = getLongField(env, creator);
DataIterHandle out;
int ret = MXDataIterCreateIter((DataIterCreator)creator,
int ret = MXDataIterCreateIter((DataIterCreator)creatorPtr,
(mx_uint) paramSize,
(const char**) keys,
(const char**) vals,
Expand Down Expand Up @@ -517,7 +523,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetLabel
env->SetLongField(ndArrayHandle, refLongFid, (jlong)out);
return ret;
}

l
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetData
(JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) {
jlong handlePtr = getLongField(env, handle);
Expand All @@ -535,7 +541,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex
uint64_t* coutIndex;
uint64_t coutSize;
int ret = MXDataIterGetIndex((DataIterHandle)handlePtr, &coutIndex, &coutSize);
//to do
//set field
setLongField(env, outSize, (long)coutSize);
// scala.collection.mutable.ListBuffer append method
jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
jmethodID listAppend = env->GetMethodID(listClass,
"$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

printf("outSize: %ld\n", coutSize);
for(int i=0; i<coutSize; i++) {
printf("%ld\t", coutIndex[i]);
env->CallObjectMethod(outIndex, listAppend, (jlong)coutIndex[i]);
}
return ret;
}

Expand Down

0 comments on commit 1bbde69

Please sign in to comment.