From 1b984278ce8c63908c20c09c8fe09bc52982881d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AD=90=E8=BD=A9?= Date: Thu, 14 Jan 2016 01:09:17 -0800 Subject: [PATCH 1/8] add script for download test data --- scala-package/core/scripts/get_cifar_data.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100755 scala-package/core/scripts/get_cifar_data.sh diff --git a/scala-package/core/scripts/get_cifar_data.sh b/scala-package/core/scripts/get_cifar_data.sh new file mode 100755 index 000000000000..48c4bfde2225 --- /dev/null +++ b/scala-package/core/scripts/get_cifar_data.sh @@ -0,0 +1,11 @@ +data_path="./data" +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +cifar_data_path="./data/cifar10.zip" +if [ ! -f "$cifar_data_path" ]; then + wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P $data_path + cd $data_path + unzip -u cifar10.zip +fi \ No newline at end of file From f9b8e3e2f7a21ccc48a4e801b655624de992a1e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AD=90=E8=BD=A9?= Date: Fri, 15 Jan 2016 01:49:53 -0800 Subject: [PATCH 2/8] fix jni typo --- .../test/scala/ml/dmlc/mxnet/IOSuite.scala | 44 +++++++++---------- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 7 ++- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index fc215d67662a..27be2eb228dc 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -49,29 +49,29 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { /** - * not work now + * for time */ -// test("test ImageRecordIter") { -// //get data -// //"./scripts/get_cifar_data.sh" ! 
-// -// val params = Map( -// "path_imgrec" -> "data/cifar/train.rec", -// "mean_img" -> "data/cifar/cifar10_mean.bin", -// "rand_crop" -> "False", -// "and_mirror" -> "False", -// "shuffle" -> "False", -// "data_shape" -> "(3,28,28)", -// "batch_size" -> "100", -// "preprocess_threads" -> "4", -// "prefetch_buffer" -> "1" -// ) -// val img_iter = IO.createIterator("ImageRecordIter", params) -// img_iter.reset() -// while(img_iter.iterNext()) { -// val batch = img_iter.next() -// } -// } + test("test ImageRecordIter") { + //get data + "./scripts/get_cifar_data.sh" ! + + val params = Map( + "path_imgrec" -> "data/cifar/train.rec", + "mean_img" -> "data/cifar/cifar10_mean.bin", + "rand_crop" -> "False", + "and_mirror" -> "False", + "shuffle" -> "False", + "data_shape" -> "(3,28,28)", + "batch_size" -> "100", + "preprocess_threads" -> "4", + "prefetch_buffer" -> "1" + ) + val img_iter = IO.createIterator("ImageRecordIter", params) + img_iter.reset() + while(img_iter.iterNext()) { + val batch = img_iter.next() + } + } // test("test NDarryIter") { // diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index cfc419f29812..a37d8a699409 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -648,8 +648,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + //long class + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V"); + for(int i=0; i<coutSize; i++) { - env->CallObjectMethod(outIndex, listAppend, (jlong)coutIndex[i]); + env->CallObjectMethod(outIndex, listAppend, + env->NewObject(longCls, longConst, coutIndex[i])); } return ret; } From 461e534ccde86e88019436b5bde80c0f61ebc4dd Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AD=90=E8=BD=A9?= Date: Sun, 24 Jan 2016 21:37:16 -0800 Subject: [PATCH 3/8] small change --- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 1 + .../test/scala/ml/dmlc/mxnet/IOSuite.scala | 44 +++++++++---------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 6cd073461250..ff4d09b82595 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -16,6 +16,7 @@ object IO { private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter") + def ImageRecordIter: IterCreateFunc = iterCreateFuncs("ImageRecordIter") /** * create iterator via iterName and params diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 87717977e284..9af09e913ac3 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -49,29 +49,29 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { /** - * for time + * disable this test for saving time */ - test("test ImageRecordIter") { - //get data - "./scripts/get_cifar_data.sh" ! - - val params = Map( - "path_imgrec" -> "data/cifar/train.rec", - "mean_img" -> "data/cifar/cifar10_mean.bin", - "rand_crop" -> "False", - "and_mirror" -> "False", - "shuffle" -> "False", - "data_shape" -> "(3,28,28)", - "batch_size" -> "100", - "preprocess_threads" -> "4", - "prefetch_buffer" -> "1" - ) - val img_iter = IO.createIterator("ImageRecordIter", params) - img_iter.reset() - while(img_iter.iterNext()) { - val batch = img_iter.next() - } - } +// test("test ImageRecordIter") { +// //get data +// "./scripts/get_cifar_data.sh" ! 
+// +// val params = Map( +// "path_imgrec" -> "data/cifar/train.rec", +// "mean_img" -> "data/cifar/cifar10_mean.bin", +// "rand_crop" -> "False", +// "and_mirror" -> "False", +// "shuffle" -> "False", +// "data_shape" -> "(3,28,28)", +// "batch_size" -> "100", +// "preprocess_threads" -> "4", +// "prefetch_buffer" -> "1" +// ) +// val imgIter = IO.ImageRecordIter(params); +// imgIter.reset() +// while(imgIter.iterNext()) { +// val batch = imgIter.next() +// } +// } // test("test NDarryIter") { // From 70c8e2c7831482899d308b256651aa3274e43530 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AD=90=E8=BD=A9?= Date: Mon, 25 Jan 2016 01:34:59 -0800 Subject: [PATCH 4/8] small change --- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index ff4d09b82595..86d90df6790d 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -146,9 +146,17 @@ abstract class DataIter(val batchSize: Int = 0) { * @param handle the handle to the underlying C++ Data Iterator */ // scalastyle:off finalize -class MXDataIter(val handle: DataIterHandle) extends DataIter { +class MXDataIter(val handle: DataIterHandle, + val dataName: String = "data", + val labelName: String = "label") extends DataIter { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) + //get first batch + reset() + iterNext() + val firstBatch = next() + reset() + override def finalize(): Unit = { checkCall(_LIB.mxDataIterFree(handle)) } @@ -221,9 +229,16 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { // scalastyle:on finalize /** - * TODO - */ -class ArrayDataIter() extends DataIter { + * Base class for prefetching iterators. 
Takes one or more DataIters + * (or any class with "reset" and "read" methods) and combines them with + * prefetching. + * @param iters list of DataIters + * @param dataNames + * @param labelNames + */ +class PrefetchingIter(val iters: List[DataIter], + val dataNames: Map[String, String] = null, + val labelNames: Map[String, String] = null) extends DataIter { /** * reset the iterator */ From e53c3eaed3561aba688dc69d7d12569f77dd1fba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AD=90=E8=BD=A9?= Date: Tue, 26 Jan 2016 01:28:59 -0800 Subject: [PATCH 5/8] add provideData provideLabel for MXDataIter --- .../core/src/main/scala/ml/dmlc/mxnet/IO.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 86d90df6790d..54a7983d580d 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -67,9 +67,9 @@ object IO { } // Convert data into canonical form. - private def initData(data: NDArray, allowEmpty: Boolean, defaultName: String) = { + private def initData(data: List[NDArray], allowEmpty: Boolean, defaultName: String) = { require(data != null || allowEmpty) - // TODO + //TODO } } @@ -221,10 +221,14 @@ class MXDataIter(val handle: DataIterHandle, } // The name and shape of data provided by this iterator - override def provideData: Map[String, Shape] = ??? + override def provideData: Map[String, Shape] = { + Map(dataName -> firstBatch.data.head.shape) + } // The name and shape of label provided by this iterator - override def provideLabel: Map[String, Shape] = ??? 
+ override def provideLabel: Map[String, Shape] = { + Map(labelName -> firstBatch.label.head.shape) + } } // scalastyle:on finalize From 0c523221313eab667ce7adc5e75780376f1a94b5 Mon Sep 17 00:00:00 2001 From: yanqingmen Date: Wed, 27 Jan 2016 21:08:52 +0800 Subject: [PATCH 6/8] add ImageRecorderIter --- .../test/scala/ml/dmlc/mxnet/IOSuite.scala | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 27be2eb228dc..554c0e583188 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -27,7 +27,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { var batchCount = 0 while(mnistIter.iterNext()) { val batch = mnistIter.next() - batchCount+=1 + batchCount += 1 } // test loop assert(nBatch === batchCount) @@ -49,11 +49,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { /** - * for time + * default skip this for saving time */ test("test ImageRecordIter") { //get data - "./scripts/get_cifar_data.sh" ! + //"./scripts/get_cifar_data.sh" ! 
val params = Map( "path_imgrec" -> "data/cifar/train.rec", @@ -66,11 +66,30 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "preprocess_threads" -> "4", "prefetch_buffer" -> "1" ) - val img_iter = IO.createIterator("ImageRecordIter", params) - img_iter.reset() - while(img_iter.iterNext()) { - val batch = img_iter.next() + val imgRecIter = IO.createIterator("ImageRecordIter", params) + val nBatch = 500 + var batchCount = 0 + imgRecIter.reset() + while(imgRecIter.iterNext()) { + val batch = imgRecIter.next() + batchCount += 1 } + // test loop + assert(batchCount === nBatch) + // test reset + imgRecIter.reset() + imgRecIter.iterNext() + val label0 = imgRecIter.getLabel().toArray + val data0 = imgRecIter.getData().toArray + imgRecIter.iterNext() + imgRecIter.iterNext() + imgRecIter.iterNext() + imgRecIter.reset() + imgRecIter.iterNext() + val label1 = imgRecIter.getLabel().toArray + val data1 = imgRecIter.getData().toArray + assert(label0 === label1) + assert(data0 === data1) } // test("test NDarryIter") { From cc54bca31fc036e981713f9e24e08efa73bce10c Mon Sep 17 00:00:00 2001 From: yanqingmen Date: Wed, 27 Jan 2016 21:33:31 +0800 Subject: [PATCH 7/8] small change from scalastyle --- scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 54a7983d580d..3fe00f7043e6 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -17,6 +17,7 @@ object IO { def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter") def ImageRecordIter: IterCreateFunc = iterCreateFuncs("ImageRecordIter") + def CSVIter: IterCreateFunc = iterCreateFuncs("CSVIter") /** * create iterator via iterName and params @@ -69,7 +70,7 @@ object IO { // Convert data into canonical form. 
private def initData(data: List[NDArray], allowEmpty: Boolean, defaultName: String) = { require(data != null || allowEmpty) - //TODO + // TODO } } @@ -151,7 +152,7 @@ class MXDataIter(val handle: DataIterHandle, val labelName: String = "label") extends DataIter { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) - //get first batch + // get first batch reset() iterNext() val firstBatch = next() From d86cc4284cd11852258ee7267a21e3673bc5f29a Mon Sep 17 00:00:00 2001 From: yanqingmen Date: Wed, 27 Jan 2016 22:13:31 +0800 Subject: [PATCH 8/8] small change for IOSuite --- .../test/scala/ml/dmlc/mxnet/IOSuite.scala | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 1a113145253f..ae184c92177f 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -21,6 +21,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { ) val mnistIter = IO.MNISTIter(params) + // test provideData + val provideData = mnistIter.provideData + val provideLabel = mnistIter.provideLabel + assert(provideData("data") === Array(100, 784)) + assert(provideLabel("label") === Array(100)) // test_loop mnistIter.reset() val nBatch = 600 @@ -52,7 +57,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { * default skip this test for saving time */ // test("test ImageRecordIter") { -// //get data +// // get data // "./scripts/get_cifar_data.sh" ! 
// // val params = Map( @@ -69,6 +74,12 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // val imgRecIter = IO.ImageRecordIter(params) // val nBatch = 500 // var batchCount = 0 +// // test provideData +// val provideData = imgRecIter.provideData +// val provideLabel = imgRecIter.provideLabel +// assert(provideData("data") === Array(100, 3, 28, 28)) +// assert(provideLabel("label") === Array(100)) +// // imgRecIter.reset() // while(imgRecIter.iterNext()) { // val batch = imgRecIter.next() @@ -79,15 +90,15 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // // test reset // imgRecIter.reset() // imgRecIter.iterNext() -// val label0 = imgRecIter.getLabel().toArray -// val data0 = imgRecIter.getData().toArray +// val label0 = imgRecIter.getLabel().head.toArray +// val data0 = imgRecIter.getData().head.toArray // imgRecIter.iterNext() // imgRecIter.iterNext() // imgRecIter.iterNext() // imgRecIter.reset() // imgRecIter.iterNext() -// val label1 = imgRecIter.getLabel().toArray -// val data1 = imgRecIter.getData().toArray +// val label1 = imgRecIter.getLabel().head.toArray +// val data1 = imgRecIter.getData().head.toArray // assert(label0 === label1) // assert(data0 === data1) // }