diff --git a/scala-package/core/scripts/get_cifar_data.sh b/scala-package/core/scripts/get_cifar_data.sh new file mode 100755 index 000000000000..48c4bfde2225 --- /dev/null +++ b/scala-package/core/scripts/get_cifar_data.sh @@ -0,0 +1,11 @@ +data_path="./data" +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +cifar_data_path="./data/cifar10.zip" +if [ ! -f "$cifar_data_path" ]; then + wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P $data_path + cd $data_path + unzip -u cifar10.zip +fi \ No newline at end of file diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 6cd073461250..3fe00f7043e6 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -16,6 +16,8 @@ object IO { private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter") + def ImageRecordIter: IterCreateFunc = iterCreateFuncs("ImageRecordIter") + def CSVIter: IterCreateFunc = iterCreateFuncs("CSVIter") /** * create iterator via iterName and params @@ -66,7 +68,7 @@ object IO { } // Convert data into canonical form. 
- private def initData(data: NDArray, allowEmpty: Boolean, defaultName: String) = { + private def initData(data: List[NDArray], allowEmpty: Boolean, defaultName: String) = { require(data != null || allowEmpty) // TODO } @@ -145,9 +147,17 @@ abstract class DataIter(val batchSize: Int = 0) { * @param handle the handle to the underlying C++ Data Iterator */ // scalastyle:off finalize -class MXDataIter(val handle: DataIterHandle) extends DataIter { +class MXDataIter(val handle: DataIterHandle, + val dataName: String = "data", + val labelName: String = "label") extends DataIter { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) + // get first batch + reset() + iterNext() + val firstBatch = next() + reset() + override def finalize(): Unit = { checkCall(_LIB.mxDataIterFree(handle)) } @@ -212,17 +222,28 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { } // The name and shape of data provided by this iterator - override def provideData: Map[String, Shape] = ??? + override def provideData: Map[String, Shape] = { + Map(dataName -> firstBatch.data.head.shape) + } // The name and shape of label provided by this iterator - override def provideLabel: Map[String, Shape] = ??? + override def provideLabel: Map[String, Shape] = { + Map(labelName -> firstBatch.label.head.shape) + } } // scalastyle:on finalize /** - * TODO - */ -class ArrayDataIter() extends DataIter { + * Base class for prefetching iterators. Takes one or more DataIters + * (or any class with "reset" and "read" methods) and combines them with + * prefetching. 
+ * @param iters list of DataIters + * @param dataNames + * @param labelNames + */ +class PrefetchingIter(val iters: List[DataIter], + val dataNames: Map[String, String] = null, + val labelNames: Map[String, String] = null) extends DataIter { /** * reset the iterator */ diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 6b45d9931ad9..ae184c92177f 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -21,13 +21,18 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { ) val mnistIter = IO.MNISTIter(params) + // test provideData + val provideData = mnistIter.provideData + val provideLabel = mnistIter.provideLabel + assert(provideData("data") === Array(100, 784)) + assert(provideLabel("label") === Array(100)) // test_loop mnistIter.reset() val nBatch = 600 var batchCount = 0 while(mnistIter.iterNext()) { val batch = mnistIter.next() - batchCount+=1 + batchCount += 1 } // test loop assert(nBatch === batchCount) @@ -49,11 +54,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { /** - * not work now + * default skip this test for saving time */ // test("test ImageRecordIter") { -// //get data -// //"./scripts/get_cifar_data.sh" ! +// // get data +// "./scripts/get_cifar_data.sh" ! 
// // val params = Map( // "path_imgrec" -> "data/cifar/train.rec", @@ -66,11 +71,36 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // "preprocess_threads" -> "4", // "prefetch_buffer" -> "1" // ) -// val img_iter = IO.createIterator("ImageRecordIter", params) -// img_iter.reset() -// while(img_iter.iterNext()) { -// val batch = img_iter.next() +// val imgRecIter = IO.ImageRecordIter(params) +// val nBatch = 500 +// var batchCount = 0 +// // test provideData +// val provideData = imgRecIter.provideData +// val provideLabel = imgRecIter.provideLabel +// assert(provideData("data") === Array(100, 3, 28, 28)) +// assert(provideLabel("label") === Array(100)) +// +// imgRecIter.reset() +// while(imgRecIter.iterNext()) { +// val batch = imgRecIter.next() +// batchCount += 1 // } +// // test loop +// assert(batchCount === nBatch) +// // test reset +// imgRecIter.reset() +// imgRecIter.iterNext() +// val label0 = imgRecIter.getLabel().head.toArray +// val data0 = imgRecIter.getData().head.toArray +// imgRecIter.iterNext() +// imgRecIter.iterNext() +// imgRecIter.iterNext() +// imgRecIter.reset() +// imgRecIter.iterNext() +// val label1 = imgRecIter.getLabel().head.toArray +// val data1 = imgRecIter.getData().head.toArray +// assert(label0 === label1) +// assert(data0 === data1) // } // test("test NDarryIter") { diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 052ef43e0fa9..a30215b067a0 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -648,8 +648,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + //long class + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = 
env->GetMethodID(longCls, "<init>", "(J)V"); + + for(int i=0; iCallObjectMethod(outIndex, listAppend, (jlong)coutIndex[i]); + env->CallObjectMethod(outIndex, listAppend, + env->NewObject(longCls, longConst, coutIndex[i])); } return ret; }