From 56f47e020319fb8697e1ddd3551b3a8fdacc1b22 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 8 Jun 2018 17:39:31 -0700 Subject: [PATCH 1/3] improve NDArrayIter to have Builder and ability to specifying names --- .../org/apache/mxnet/io/NDArrayIter.scala | 90 +++++++++++++------ .../test/scala/org/apache/mxnet/IOSuite.scala | 8 +- 2 files changed, 70 insertions(+), 28 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 51089382097b..ed3c5adffd2e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -23,6 +23,7 @@ import org.apache.mxnet.Base._ import org.apache.mxnet._ import org.slf4j.LoggerFactory +import scala.annotation.varargs import scala.collection.immutable.ListMap /** @@ -38,15 +39,23 @@ import scala.collection.immutable.ListMap * the size of data does not match batch_size. Roll over is intended * for training and can cause problems if used for prediction. */ -class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, - private val dataBatchSize: Int = 1, shuffle: Boolean = false, - lastBatchHandle: String = "pad", - dataName: String = "data", labelName: String = "label") extends DataIter { - private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) +class NDArrayIter(data: IndexedSeq[(String, NDArray)], + label: IndexedSeq[(String, NDArray)], + private val dataBatchSize: Int, shuffle: Boolean, + lastBatchHandle: String) extends DataIter { + + def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, + dataBatchSize: Int = 1, shuffle: Boolean = false, + lastBatchHandle: String = "pad", + dataName: String = "data", labelName: String = "label") { + this(IO.initData(data, allowEmpty = false, dataName), + IO.initData(label, allowEmpty = true, labelName), + dataBatchSize, shuffle, lastBatchHandle) + } + private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) - private val (_dataList: IndexedSeq[NDArray], - _labelList: IndexedSeq[NDArray]) = { + val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = { // data should not be null and size > 0 require(data != null && data.size > 0, "data should not be null and data.size should not be zero") @@ -55,17 +64,17 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index "label should not be null. Use IndexedSeq.empty if there are no labels") // shuffle is not supported currently - require(shuffle == false, "shuffle is not supported currently") + require(!shuffle, "shuffle is not supported currently") // discard final part if lastBatchHandle equals discard if (lastBatchHandle.equals("discard")) { - val dataSize = data(0).shape(0) + val dataSize = data(0)._2.shape(0) require(dataBatchSize <= dataSize, "batch_size need to be smaller than data size when not padding.") val keepSize = dataSize - dataSize % dataBatchSize - val dataList = data.map(ndArray => {ndArray.slice(0, keepSize)}) + val dataList = data.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) } if (!label.isEmpty) { - val labelList = label.map(ndArray => {ndArray.slice(0, keepSize)}) + val labelList = label.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) } (dataList, labelList) } else { (dataList, label) @@ -75,13 +84,9 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index } } - - val initData: IndexedSeq[(String, NDArray)] = IO.initData(_dataList, false, dataName) - val initLabel: IndexedSeq[(String, NDArray)] = IO.initData(_labelList, true, labelName) - val numData = _dataList(0).shape(0) - val numSource = initData.size - var cursor = -dataBatchSize - + val numData = initData(0)._2.shape(0) + val numSource: MXUint = initData.size + private var cursor = -dataBatchSize private val (_provideData: ListMap[String, Shape], _provideLabel: ListMap[String, Shape]) = { @@ -112,8 +117,8 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index * reset the iterator */ override def reset(): Unit = { - if (lastBatchHandle.equals("roll_over") && cursor>numData) { - cursor = -dataBatchSize + (cursor%numData)%dataBatchSize + if (lastBatchHandle.equals("roll_over") && cursor > numData) { + cursor = -dataBatchSize + (cursor%numData) % dataBatchSize } else { cursor = -dataBatchSize } @@ -154,16 +159,16 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index newArray } - private def _getData(data: IndexedSeq[NDArray]): IndexedSeq[NDArray] = { + private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = { require(cursor < numData, "DataIter needs reset.") if (data == null) { null } else { if (cursor + dataBatchSize <= numData) { - data.map(ndArray => {ndArray.slice(cursor, cursor + dataBatchSize)}).toIndexedSeq + data.map { case (_, ndArray) => ndArray.slice(cursor, cursor + dataBatchSize) } } else { // padding - data.map(_padData).toIndexedSeq + data.map { case (_, ndArray) => _padData(ndArray) } } } } @@ -173,7 +178,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index * @return the data of current batch */ override def getData(): IndexedSeq[NDArray] = { - _getData(_dataList) + _getData(initData) } /** @@ -181,7 +186,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index * @return the label of current batch */ override def getLabel(): IndexedSeq[NDArray] = { - _getData(_labelList) + _getData(initLabel) } /** @@ -189,7 +194,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index * @return */ override def getIndex(): IndexedSeq[Long] = { - (cursor.toLong to (cursor + dataBatchSize).toLong).toIndexedSeq + cursor.toLong to (cursor + dataBatchSize).toLong } /** @@ -213,3 +218,36 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index override def batchSize: Int = dataBatchSize } + +object NDArrayIter { + class Builder() { + private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty + private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty + private var dataBatchSize: Int = 1 + private var lastBatchHandle: String = "pad" + + def addData(name: String, data: NDArray): Builder = { + this.data = this.data ++ IndexedSeq((name, data)) + this + } + + def addLabel(name: String, label: NDArray): Builder = { + this.label = this.label ++ IndexedSeq((name, label)) + this + } + + def setBatchSize(batchSize: Int): Builder = { + this.dataBatchSize = batchSize + this + } + + def setLastBatchHandle(lastBatchHandle: String): Builder = { + this.lastBatchHandle = lastBatchHandle + this + } + + def build(): NDArrayIter = { + new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle) + } + } +} diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 0f4b7c0e7a3d..1b922b3c05b6 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -24,7 +24,7 @@ import scala.sys.process._ class IOSuite extends FunSuite with BeforeAndAfterAll { - private var tu = new TestUtil + private val tu = new TestUtil test("test MNISTIter & MNISTPack") { // get data @@ -258,7 +258,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch0) // test discard - val dataIter1 = new NDArrayIter(data, label, 128, false, "discard") + val dataIter1 = new NDArrayIter.Builder() + .addData("data0", data(0)).addData("data1", data(1)) + .addLabel("label", label(0)) + .setBatchSize(128) + .setLastBatchHandle("discard").build() val nBatch1 = 7 batchCount = 0 while(dataIter1.hasNext) { From 0b4efe9bf3ca935ac2060dac914cf1274521519e Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 12 Jun 2018 23:00:54 -0700 Subject: [PATCH 2/3] javadoc for NDArrayIter --- .../org/apache/mxnet/io/NDArrayIter.scala | 45 ++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index ed3c5adffd2e..dae0a6a573a1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -23,13 +23,12 @@ import org.apache.mxnet.Base._ import org.apache.mxnet._ import org.slf4j.LoggerFactory -import scala.annotation.varargs import scala.collection.immutable.ListMap /** * NDArrayIter object in mxnet. Taking NDArray to get dataiter. * - * @param data NDArrayIter supports single or multiple data and label. + * @param data Specify the data as well as the name. NDArrayIter supports single or multiple data and label. * @param label Same as data, but is not fed to the model during testing. * @param dataBatchSize Batch Size * @param shuffle Whether to shuffle the data @@ -44,6 +43,18 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, lastBatchHandle: String) extends DataIter { + /** + * @param data Specify the data. Data names will be data_0, data_1, ..., etc. + * @param label Same as data, but is not fed to the model during testing. + * Label names will be label_0, label_1, ..., etc. + * @param dataBatchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", @@ -220,32 +231,62 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], } object NDArrayIter { + + /** + * Builder class for NDArrayIter. + */ class Builder() { private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" + /** + * Add one data input with its name. + * @param name Data name. + * @param data Data nd-array. + * @return The builder object itself. + */ def addData(name: String, data: NDArray): Builder = { this.data = this.data ++ IndexedSeq((name, data)) this } + /** + * Add one label input with its name. + * @param name Label name. + * @param label Label nd-array. + * @return The builder object itself. + */ def addLabel(name: String, label: NDArray): Builder = { this.label = this.label ++ IndexedSeq((name, label)) this } + /** + * Set the batch size of the iterator. + * @param batchSize batch size. + * @return The builder object itself. + */ def setBatchSize(batchSize: Int): Builder = { this.dataBatchSize = batchSize this } + /** + * How to handle the last batch. + * @param lastBatchHandle Can be "pad", "discard" or "roll_over". + * @return The builder object itself. + */ def setLastBatchHandle(lastBatchHandle: String): Builder = { this.lastBatchHandle = lastBatchHandle this } + /** + * Build the NDArrayIter object. + * @return the built object. + */ def build(): NDArrayIter = { new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle) } From 1b14e6cc6263f37d6d034ff655b0081a78ce7442 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 12 Jun 2018 23:44:55 -0700 Subject: [PATCH 3/3] fix lint --- .../core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index dae0a6a573a1..70c648778870 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -28,7 +28,8 @@ import scala.collection.immutable.ListMap /** * NDArrayIter object in mxnet. Taking NDArray to get dataiter. * - * @param data Specify the data as well as the name. NDArrayIter supports single or multiple data and label. + * @param data Specify the data as well as the name. + * NDArrayIter supports single or multiple data and label. * @param label Same as data, but is not fed to the model during testing. * @param dataBatchSize Batch Size * @param shuffle Whether to shuffle the data