From 21aaea8bdffb5ca6b66313c31ec88bcc2c40a933 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 20 Jul 2018 15:08:15 -0700 Subject: [PATCH 01/25] add dataDesc --- .../src/main/scala/org/apache/mxnet/IO.scala | 49 ++++++++++- .../scala/org/apache/mxnet/RecordIO.scala | 5 +- .../org/apache/mxnet/io/MXDataIter.scala | 55 ++++++++++-- .../org/apache/mxnet/io/NDArrayIter.scala | 64 ++++++++++++-- .../org/apache/mxnet/io/PrefetchingIter.scala | 84 ++++++++++++++++--- .../org/apache/mxnet/io/ResizeIter.scala | 29 ++++++- .../multitask/ExampleMultiTask.scala | 33 +++++--- .../apache/mxnetexamples/rnn/BucketIo.scala | 43 +++++++--- .../mxnet/spark/io/LabeledPointIter.scala | 22 ++++- .../mxnet/spark/io/LongLivingDataBatch.scala | 8 +- .../org/apache/mxnet/spark/io/PointIter.scala | 22 ++++- 11 files changed, 356 insertions(+), 58 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 47fd4eee939a..607a0c782619 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -140,7 +140,9 @@ class DataBatch(val data: IndexedSeq[NDArray], // use ListMap to indicate the order of data/label loading // (must match the order of input data/label) private val providedData: ListMap[String, Shape] = null, - private val providedLabel: ListMap[String, Shape] = null) { + private val providedLabel: ListMap[String, Shape] = null, + val dtype: DType = Base.MX_REAL_TYPE, + val layout: String = "NCHW") { /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -170,6 +172,8 @@ object DataBatch { private var label: IndexedSeq[NDArray] = null private var index: IndexedSeq[Long] = null private var pad: Int = 0 + private var layout: String = "NCHW" + private var dtype: DType = Base.MX_REAL_TYPE private var bucketKey: AnyRef = null private var datatShapes: ListMap[String, Shape] = null private var labelShapes: ListMap[String, Shape] = null @@ -216,6 +220,26 @@ object DataBatch { this } + /** + * Set the dtype. + * @param dtype The dtype of the label, default is Float32 + * @return this + */ + def setDType(dtype: DType): Builder = { + this.dtype = dtype + this + } + + /** + * Set the layout. + * @param layout The layout of the label, default is NCHW + * @return this + */ + def setLayout(layout: String): Builder = { + this.layout = layout + this + } + /** * Set the bucket key, used for bucketing module. * @param bucketKey the bucket key related to this batch. @@ -258,7 +282,7 @@ object DataBatch { def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes) + new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, dtype, layout) } } } @@ -280,7 +304,8 @@ abstract class DataIter extends Iterator[DataBatch] { */ @throws(classOf[NoSuchElementException]) def next(): DataBatch = { - new DataBatch(getData(), getLabel(), getIndex(), getPad()) + new DataBatch(getData(), getLabel(), getIndex(), getPad(), + dtype = getDType(), layout = getLayout()) } /** @@ -302,6 +327,18 @@ abstract class DataIter extends Iterator[DataBatch] { */ def getPad(): Int + /** + * Get the DType + * @return DType of the DataIter + */ + def getDType(): DType + + /** + * Get the layout + * @return layout of the DataIter + */ + def getLayout(): String + /** * Get the index of current batch * @return the index of current batch @@ -314,6 +351,12 @@ abstract class DataIter extends Iterator[DataBatch] { // The name and shape of label provided by this iterator def provideLabel: ListMap[String, Shape] + // Provide type:DataDesc of the data + def provideDataDesc: IndexedSeq[DataDesc] + + // Provide type:DataDesc of the label + def provideLabelDesc: IndexedSeq[DataDesc] + // For bucketing io only // The bucket key for the default symbol. def defaultBucketKey: AnyRef = null diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala index ee3e950512e7..578f00a76f9a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala @@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream /** * Scala interface for read/write RecordIO data format - * - * @author Depeng Liang - * * @param uri, path to recordIO file. * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing. */ @@ -144,7 +141,7 @@ object MXRecordIO { * * @author Depeng Liang * - * @param idx_path, path to index file + * @param idxPath, path to index file * @param uri, path to recordIO file. * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing. * @param keyType, data type for keys. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 2a0c333ebf10..f83bcc4207b2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -18,7 +18,8 @@ package org.apache.mxnet.io import org.apache.mxnet.Base._ -import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, WarnIfNotDisposed} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.apache.mxnet.IO._ import org.slf4j.LoggerFactory @@ -42,20 +43,41 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, private var currentBatch: DataBatch = null private val (_provideData: ListMap[String, Shape], - _provideLabel: ListMap[String, Shape], - _batchSize: Int) = + _provideLabel: ListMap[String, Shape]) = if (hasNext) { iterNext() val data = currentBatch.data(0) val label = currentBatch.label(0) // properties - val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), data.shape(0)) + val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape)) + currentBatch.dispose() + reset() + res + } else { + (null, null) + } + + private val (_provideDataDesc: IndexedSeq[DataDesc], + _provideLabelDesc: IndexedSeq[DataDesc], + _batchSize: Int) = { + if (hasNext) { + iterNext() + val data = currentBatch.data(0) + val label = currentBatch.label(0) + val dType = currentBatch.dtype + val layout = currentBatch.layout + // properties + val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, layout)), + IndexedSeq(new DataDesc(labelName, label.shape, dType, layout)), + data.shape(0)) currentBatch.dispose() reset() res } else { (null, null, 0) } + } + private var disposed = false protected def isDisposed = disposed @@ -101,10 +123,11 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, private def iterNext(): Boolean = { val next = new RefInt checkCall(_LIB.mxDataIterNext(handle, next)) - currentBatch = null if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), index = getIndex(), pad = getPad()) + } else { + currentBatch = null } next.value > 0 } @@ -151,12 +174,34 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, out.value } + /** + * Get the DType + * @return DType + */ + def getDType(): DType = { + currentBatch.dtype + } + + /** + * Get the layout + * @return layout + */ + def getLayout(): String = { + currentBatch.layout + } + // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = _provideLabel + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def hasNext: Boolean = { if (currentBatch != null) { true diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 10461315c198..d34390bacdca 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -20,6 +20,7 @@ package org.apache.mxnet.io import java.util.NoSuchElementException import org.apache.mxnet.Base._ +import org.apache.mxnet.DType.DType import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -42,7 +43,7 @@ import scala.collection.immutable.ListMap class NDArrayIter(data: IndexedSeq[(String, NDArray)], label: IndexedSeq[(String, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, - lastBatchHandle: String) extends DataIter { + lastBatchHandle: String, dtype: DType, layout: String) extends DataIter { /** * @param data Specify the data. Data names will be data_0, data_1, ..., etc. @@ -59,10 +60,11 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", - dataName: String = "data", labelName: String = "label") { + dataName: String = "data", labelName: String = "label", + dType: DType = MX_REAL_TYPE, layout: String = "NCHW") { this(IO.initData(data, allowEmpty = false, dataName), IO.initData(label, allowEmpty = true, labelName), - dataBatchSize, shuffle, lastBatchHandle) + dataBatchSize, shuffle, lastBatchHandle, dType, layout) } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) @@ -107,6 +109,13 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], (pData, pLabel) } + private val (_provideDataDesc: IndexedSeq[DataDesc], + _provideLabelDesc: IndexedSeq[DataDesc]) = { + val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout)) + val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout)) + (pData, pLabel) + } + /** * get shape via dataBatchSize * @param dataItem @@ -148,7 +157,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], override def next(): DataBatch = { if (hasNext) { cursor += dataBatchSize - new DataBatch(getData(), getLabel(), getIndex(), getPad()) + new DataBatch(getData(), getLabel(), getIndex(), getPad(), + dtype = getDType(), layout = getLayout()) } else { throw new NoSuchElementException } @@ -223,12 +233,34 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], } } + /** + * Get the DType + * @return DType + */ + def getDType(): DType = { + dtype + } + + /** + * Get the layout + * @return layout + */ + def getLayout(): String = { + layout + } + // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = _provideLabel + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def batchSize: Int = dataBatchSize } @@ -242,6 +274,8 @@ object NDArrayIter { private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" + private var layout: String = "NCHW" + private var dtype: DType = Base.MX_REAL_TYPE /** * Add one data input with its name. @@ -285,12 +319,32 @@ object NDArrayIter { this } + /** + * Set the dtype. + * @param dtype The dtype of the label, default is Float32 + * @return this + */ + def setDType(dtype: DType): Builder = { + this.dtype = dtype + this + } + + /** + * Set the layout. + * @param layout The layout of the label, default is NCHW + * @return this + */ + def setLayout(layout: String): Builder = { + this.layout = layout + this + } + /** * Build the NDArrayIter object. * @return the built object. */ def build(): NDArrayIter = { - new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle) + new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, dtype, layout) } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index c0c0d1793b54..f8f589f5faa5 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -17,10 +17,12 @@ package org.apache.mxnet.io -import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape} +import org.apache.mxnet._ import org.slf4j.LoggerFactory import java.util.concurrent.Semaphore +import org.apache.mxnet.DType.DType + import scala.collection.immutable.ListMap /** @@ -68,6 +70,42 @@ class PrefetchingIter( } } + private val _provideDataDesc: IndexedSeq[DataDesc] = { + if (dataNames == null) { + iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } else { + iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2)) + .map(m => + m._1.map(t => + new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout) + ) + ) + .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } + } + + private val _provideLabelDesc: IndexedSeq[DataDesc] = { + if (dataNames == null) { + iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } else { + iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2)) + .map(m => + m._1.map(t => + new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout) + ) + ) + .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } + } + private val _batchSize: Int = this._provideData.toList(0)._2(0) private val dataReady: IndexedSeq[Semaphore] = (0 until iters.length).map(i => new Semaphore(0)) @@ -132,19 +170,41 @@ class PrefetchingIter( */ override def getIndex(): IndexedSeq[Long] = currentBatch.index - // The name and shape of label provided by this iterator - override def provideLabel: ListMap[String, Shape] = this._provideLabel - /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): Int = this.currentBatch.pad + /** + * Get the DType + * @return DType + */ + def getDType(): DType = { + currentBatch.dtype + } + + /** + * Get the layout + * @return layout + */ + def getLayout(): String = { + currentBatch.layout + } + + // The name and shape of label provided by this iterator + override def provideLabel: ListMap[String, Shape] = this._provideLabel + // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this._provideData + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def hasNext: Boolean = { for (e <- dataReady) e.acquire() if (nextBatch(0) == null) { @@ -161,9 +221,11 @@ class PrefetchingIter( val datas = for (batch <- nextBatch) yield batch.data val labels = for (batch <- nextBatch) yield batch.label currentBatch = new DataBatch(datas.toIndexedSeq.flatten, - labels.toIndexedSeq.flatten, - nextBatch(0).index, - nextBatch(0).pad) + labels.toIndexedSeq.flatten, + nextBatch(0).index, + nextBatch(0).pad, + layout = nextBatch(0).layout, + dtype = nextBatch(0).dtype) for (e <- dataTaken) e.release() true } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index 75d88d1ae72f..228ba72c97ed 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -19,7 +19,8 @@ package org.apache.mxnet.io import java.util.NoSuchElementException -import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.slf4j.LoggerFactory import scala.collection.immutable.ListMap @@ -128,6 +129,22 @@ class ResizeIter( currentBatch.pad } + /** + * Get the DType + * @return DType + */ + def getDType(): DType = { + currentBatch.dtype + } + + /** + * Get the layout + * @return layout + */ + def getLayout(): String = { + currentBatch.layout + } + override def batchSize: Int = { dataIter.batchSize } @@ -141,4 +158,14 @@ class ResizeIter( override def provideLabel: ListMap[String, Shape] = { dataIter.provideLabel } + + // The name and shape of data provided by this iterator + override def provideDataDesc: IndexedSeq[DataDesc] = { + dataIter.provideDataDesc + } + + // The name and shape of label provided by this iterator + override def provideLabelDesc: IndexedSeq[DataDesc] = { + dataIter.provideLabelDesc + } } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 9df2bcc0566d..8decdfe8b8d4 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -25,14 +25,9 @@ import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ import org.apache.commons.io.FileUtils -import org.apache.mxnet.Symbol -import org.apache.mxnet.DataIter -import org.apache.mxnet.DataBatch -import org.apache.mxnet.NDArray -import org.apache.mxnet.Shape -import org.apache.mxnet.EvalMetric -import org.apache.mxnet.Context -import org.apache.mxnet.Xavier + +import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, NDArray, Shape, Symbol, Xavier} +import org.apache.mxnet.DType.DType import org.apache.mxnet.optimizer.RMSProp import org.apache.mxnet.Executor import org.apache.mxnetexamples.Util @@ -70,9 +65,9 @@ object ExampleMultiTask { val batch = this.dataIter.next() val label = batch.label(0) new DataBatch(batch.data, - IndexedSeq(label, label), - batch.index, - batch.pad) + IndexedSeq(label, label), + batch.index, + batch.pad, dtype = batch.dtype, layout = batch.layout) } else { throw new NoSuchElementException } @@ -114,6 +109,16 @@ object ExampleMultiTask { "softmax2_label" -> provideLabel(0)._2) } + // The name and shape of label provided by this iterator + override def provideLabelDesc: IndexedSeq[DataDesc] = { + val head = this.dataIter.provideLabelDesc(0) + // Different labels should be used here for actual application + IndexedSeq( + new DataDesc("softmax1_label", head.shape, head.dtype, head.layout), + new DataDesc("softmax2_label", head.shape, head.dtype, head.layout) + ) + } + /** * get the number of padding examples * in current batch @@ -121,9 +126,15 @@ object ExampleMultiTask { */ override def getPad(): Int = this.dataIter.getPad() + override def getDType(): DType = this.dataIter.getDType() + + override def getLayout(): String = this.dataIter.getLayout() + // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this.dataIter.provideData + override def provideDataDesc: IndexedSeq[DataDesc] = this.dataIter.provideDataDesc + override def hasNext: Boolean = this.dataIter.hasNext } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index f0eae6890c52..e37a3265d322 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -18,8 +18,10 @@ package org.apache.mxnetexamples.rnn -import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.slf4j.LoggerFactory + import scala.collection.immutable.ListMap import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -95,7 +97,8 @@ object BucketIo { path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int], _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, - readContent: ReadContent = defaultReadContent) extends DataIter { + readContent: ReadContent = defaultReadContent, layout: String = "NT", + dtype : DType = DType.Float32) extends DataIter { private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter]) @@ -165,8 +168,18 @@ object BucketIo { private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey)) tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2)) } + private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, _defaultBucketKey)) + private val _provideDataDesc = { + val tmp = IndexedSeq(new DataDesc("data", + Shape(_batchSize, _defaultBucketKey), dtype, layout)) + tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dtype, layout)) + } + + private val _provideLabelDesc = IndexedSeq(new DataDesc("softmax_label", + Shape(_batchSize, _defaultBucketKey), dtype, layout)) + private var iBucket = 0 override def next(): DataBatch = { @@ -197,7 +210,7 @@ object BucketIo { getIndex(), getPad(), this.buckets(bucketIdx).asInstanceOf[AnyRef], - batchProvideData, batchProvideLabel) + batchProvideData, batchProvideLabel, dtype, layout) } /** @@ -228,19 +241,29 @@ object BucketIo { */ override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]() - // The name and shape of label provided by this iterator - override def provideLabel: ListMap[String, Shape] = this._provideLabel - /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): Int = 0 + override def getDType(): DType = dtype + + override def getLayout(): String = layout + + // The name and shape of label provided by this iterator + override def provideLabel: ListMap[String, Shape] = this._provideLabel + // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this._provideData + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def hasNext: Boolean = { iBucket < bucketPlan.length } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index adc723ecdacb..db491b497b93 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.spark.io -import org.apache.mxnet.{DataBatch, NDArray, Shape, DataIter} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.apache.spark.mllib.regression.LabeledPoint import scala.collection.immutable.ListMap @@ -32,7 +33,9 @@ class LabeledPointIter private[mxnet]( private val dimension: Shape, private val _batchSize: Int, private val dataName: String = "data", - private val labelName: String = "label") extends DataIter { + private val labelName: String = "label", + private val dtype: DType = DType.Float32, + private val layout: String = "NCHW") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -72,7 +75,8 @@ class LabeledPointIter private[mxnet]( } val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( - IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad) + IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, + layout, dtype) cache += dataBatch dataBatch } @@ -124,6 +128,14 @@ class LabeledPointIter private[mxnet]( ListMap(dataName -> dataShape) } + override def provideDataDesc: IndexedSeq[DataDesc] = { + IndexedSeq(new DataDesc(dataName, dataShape, dtype, layout)) + } + + override def provideLabelDesc: IndexedSeq[DataDesc] = { + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, layout)) + } + /** * Get the number of padding examples * in current batch @@ -131,6 +143,10 @@ class LabeledPointIter private[mxnet]( */ override def getPad(): Int = 0 + override def getDType(): DType = dtype + + override def getLayout(): String = layout + override def batchSize: Int = _batchSize override def hasNext: Boolean = { diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index 339f7e2d76ca..d062a81b9bcc 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.spark.io -import org.apache.mxnet.{NDArray, DataBatch} +import org.apache.mxnet.DType.DType +import org.apache.mxnet.{DataBatch, NDArray} /** * Dispose only when 'disposeForce' called @@ -27,7 +28,10 @@ class LongLivingDataBatch( override val data: IndexedSeq[NDArray], override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], - override val pad: Int) extends DataBatch(data, label, index, pad) { + override val pad: Int, + override val layout: String, + override val dtype: DType) extends DataBatch(data, label, index, pad, + layout = layout, dtype = dtype) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index 21329291cfb5..c0d898c72fc7 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.spark.io -import org.apache.mxnet.{NDArray, DataBatch, DataIter, Shape} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.apache.spark.mllib.linalg.Vector import scala.collection.immutable.ListMap @@ -32,7 +33,9 @@ class PointIter private[mxnet]( private val dimension: Shape, private val _batchSize: Int, private val dataName: String = "data", - private val labelName: String = "label") extends DataIter { + private val labelName: String = "label", + private val dtype: DType = DType.Float32, + private val layout: String = "NCHW") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -71,7 +74,8 @@ class PointIter private[mxnet]( } val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( - IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad) + IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, + layout, dtype) cache += dataBatch dataBatch } @@ -123,6 +127,14 @@ class PointIter private[mxnet]( ListMap(dataName -> dataShape) } + override def provideDataDesc: IndexedSeq[DataDesc] = { + IndexedSeq(new DataDesc(dataName, dataShape, dtype, layout)) + } + + override def provideLabelDesc: IndexedSeq[DataDesc] = { + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, layout)) + } + /** * Get the number of padding examples * in current batch @@ -130,6 +142,10 @@ class PointIter private[mxnet]( */ override def getPad(): Int = 0 + override def getDType(): DType = dtype + + override def getLayout(): String = layout + override def batchSize: Int = _batchSize override def hasNext: Boolean = { From 0ca6088c49a9b6300d9e8a3d2598997383a204ae Mon Sep 17 00:00:00 2001 From: Qing Date: Sun, 22 Jul 2018 10:55:56 -0700 Subject: [PATCH 02/25] Add amend --- .../core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index f83bcc4207b2..4c6a64d99034 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -125,7 +125,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, checkCall(_LIB.mxDataIterNext(handle, next)) if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), - index = getIndex(), pad = getPad()) + index = getIndex(), pad = getPad(), + dtype = currentBatch.dtype, layout = currentBatch.layout) } else { currentBatch = null } From 88a00433f5176a7391a06b35bc0584e7c7207130 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 24 Jul 2018 13:45:28 -0700 Subject: [PATCH 03/25] add changes with dataLayout and labelLayout --- .../src/main/scala/org/apache/mxnet/IO.scala | 28 ++++++++++------ .../org/apache/mxnet/io/MXDataIter.scala | 23 ++++++------- .../org/apache/mxnet/io/NDArrayIter.scala | 32 +++++++++++-------- .../org/apache/mxnet/io/PrefetchingIter.scala | 7 ++-- .../org/apache/mxnet/io/ResizeIter.scala | 4 +-- .../test/scala/org/apache/mxnet/IOSuite.scala | 23 +++++++++---- .../scala/org/apache/mxnet/ModuleSuite.scala | 6 ++-- .../multitask/ExampleMultiTask.scala | 5 +-- .../apache/mxnetexamples/rnn/BucketIo.scala | 15 +++++---- .../mxnet/spark/io/LabeledPointIter.scala | 11 ++++--- .../mxnet/spark/io/LongLivingDataBatch.scala | 5 +-- .../org/apache/mxnet/spark/io/PointIter.scala | 11 ++++--- 12 files changed, 103 insertions(+), 67 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 607a0c782619..57eca2140664 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -106,7 +106,10 @@ object IO { checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) val dataName = params.getOrElse("data_name", "data") val labelName = params.getOrElse("label_name", "label") - new MXDataIter(out.value, dataName, labelName) + val dataLayout = params.getOrElse("dataLayout", "NCHW") + val labelLayout = params.getOrElse("labelLayout", "N") + new MXDataIter(out.value, dataName, labelName, + dataLayout = dataLayout, labelLayout = labelLayout) } // Convert data into canonical form. @@ -142,7 +145,8 @@ class DataBatch(val data: IndexedSeq[NDArray], private val providedData: ListMap[String, Shape] = null, private val providedLabel: ListMap[String, Shape] = null, val dtype: DType = Base.MX_REAL_TYPE, - val layout: String = "NCHW") { + val dataLayout: String = "NCHW", + val labelLayout: String = "N") { /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -172,7 +176,8 @@ object DataBatch { private var label: IndexedSeq[NDArray] = null private var index: IndexedSeq[Long] = null private var pad: Int = 0 - private var layout: String = "NCHW" + private var dataLayout: String = "NCHW" + private var labelLayout: String = "N" private var dtype: DType = Base.MX_REAL_TYPE private var bucketKey: AnyRef = null private var datatShapes: ListMap[String, Shape] = null @@ -232,11 +237,13 @@ object DataBatch { /** * Set the layout. - * @param layout The layout of the label, default is NCHW + * @param dataLayout The layout of the data, default is NCHW + * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(layout: String): Builder = { - this.layout = layout + def setLayout(dataLayout: String, labelLayout: String): Builder = { + this.dataLayout = dataLayout + this.labelLayout = labelLayout this } @@ -282,7 +289,8 @@ object DataBatch { def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, dtype, layout) + new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, + dtype, dataLayout, labelLayout) } } } @@ -305,7 +313,7 @@ abstract class DataIter extends Iterator[DataBatch] { @throws(classOf[NoSuchElementException]) def next(): DataBatch = { new DataBatch(getData(), getLabel(), getIndex(), getPad(), - dtype = getDType(), layout = getLayout()) + dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2) } /** @@ -335,9 +343,9 @@ abstract class DataIter extends Iterator[DataBatch] { /** * Get the layout - * @return layout of the DataIter + * @return data and label layout of the DataIter */ - def getLayout(): String + def getLayout(): (String, String) /** * Get the index of current batch diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 4c6a64d99034..5d8fd6d07512 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -32,7 +32,10 @@ import scala.collection.mutable.ListBuffer */ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, dataName: String = "data", - labelName: String = "label") + labelName: String = "label", + dtype: DType = DType.Float32, + dataLayout: String = "NCHW", + labelLayout: String = "N") extends DataIter with WarnIfNotDisposed { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) @@ -65,10 +68,11 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, val data = currentBatch.data(0) val label = currentBatch.label(0) val dType = currentBatch.dtype - val layout = currentBatch.layout + val dataLayout = currentBatch.dataLayout + val labelLayout = currentBatch.labelLayout // properties - val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, layout)), - IndexedSeq(new DataDesc(labelName, label.shape, dType, layout)), + val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, dataLayout)), + IndexedSeq(new DataDesc(labelName, label.shape, dType, labelLayout)), data.shape(0)) currentBatch.dispose() reset() @@ -126,7 +130,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), index = getIndex(), pad = getPad(), - dtype = currentBatch.dtype, layout = currentBatch.layout) + dtype = getDType(), dataLayout = getLayout()._1, + labelLayout = getLayout()._2) } else { currentBatch = null } @@ -179,17 +184,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, * Get the DType * @return DType */ - def getDType(): DType = { - currentBatch.dtype - } + def getDType(): DType = dtype /** * Get the layout * @return layout */ - def getLayout(): String = { - currentBatch.layout - } + def getLayout(): (String, String) = (dataLayout, labelLayout) // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index d34390bacdca..53f7c352c636 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -43,7 +43,8 @@ import scala.collection.immutable.ListMap class NDArrayIter(data: IndexedSeq[(String, NDArray)], label: IndexedSeq[(String, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, - lastBatchHandle: String, dtype: DType, layout: String) extends DataIter { + lastBatchHandle: String, + dtype: DType, dataLayout: String, labelLayout: String) extends DataIter { /** * @param data Specify the data. Data names will be data_0, data_1, ..., etc. @@ -61,10 +62,11 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label", - dType: DType = MX_REAL_TYPE, layout: String = "NCHW") { + dType: DType = MX_REAL_TYPE, dataLayout: String = "NCHW", + labelLayout: String = "N") { this(IO.initData(data, allowEmpty = false, dataName), IO.initData(label, allowEmpty = true, labelName), - dataBatchSize, shuffle, lastBatchHandle, dType, layout) + dataBatchSize, shuffle, lastBatchHandle, dType, dataLayout, labelLayout) } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) @@ -111,8 +113,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private val (_provideDataDesc: IndexedSeq[DataDesc], _provideLabelDesc: IndexedSeq[DataDesc]) = { - val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout)) - val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout)) + val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, dataLayout)) + val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, labelLayout)) (pData, pLabel) } @@ -158,7 +160,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], if (hasNext) { cursor += dataBatchSize new DataBatch(getData(), getLabel(), getIndex(), getPad(), - dtype = getDType(), layout = getLayout()) + dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2) } else { throw new NoSuchElementException } @@ -245,8 +247,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], * Get the layout * @return layout */ - def getLayout(): String = { - layout + def getLayout(): (String, String) = { + (dataLayout, labelLayout) } // The name and shape of data provided by this iterator @@ -274,7 +276,8 @@ object NDArrayIter { private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" - private var layout: String = "NCHW" + private var dataLayout: String = "NCHW" + private var labelLayout: String = "N" private var dtype: DType = Base.MX_REAL_TYPE /** @@ -331,11 +334,13 @@ object NDArrayIter { /** * Set the layout. - * @param layout The layout of the label, default is NCHW + * @param dataLayout The layout of the data, default is NCHW + * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(layout: String): Builder = { - this.layout = layout + def setLayout(dataLayout: String, labelLayout: String): Builder = { + this.dataLayout = dataLayout + this.labelLayout = labelLayout this } @@ -344,7 +349,8 @@ object NDArrayIter { * @return the built object. */ def build(): NDArrayIter = { - new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, dtype, layout) + new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, + dtype, dataLayout, labelLayout) } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index f8f589f5faa5..3097948b5de9 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -189,8 +189,8 @@ class PrefetchingIter( * Get the layout * @return layout */ - def getLayout(): String = { - currentBatch.layout + def getLayout(): (String, String) = { + (currentBatch.dataLayout, currentBatch.labelLayout) } // The name and shape of label provided by this iterator @@ -224,7 +224,8 @@ class PrefetchingIter( labels.toIndexedSeq.flatten, nextBatch(0).index, nextBatch(0).pad, - layout = nextBatch(0).layout, + dataLayout = nextBatch(0).dataLayout, + labelLayout = nextBatch(0).labelLayout, dtype = nextBatch(0).dtype) for (e <- dataTaken) e.release() true diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index 228ba72c97ed..6e521c219128 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -141,8 +141,8 @@ class ResizeIter( * Get the layout * @return layout */ - def getLayout(): String = { - currentBatch.layout + def getLayout(): (String, String) = { + (currentBatch.dataLayout, currentBatch.labelLayout) } override def batchSize: Int = { diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 1b922b3c05b6..478f834d210b 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -38,7 +38,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10" + "seed" -> "10", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistPack = IO.MNISTPack(params) @@ -99,7 +101,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "data_shape" -> "(3,28,28)", "batch_size" -> "100", "preprocess_threads" -> "4", - "prefetch_buffer" -> "1" + "prefetch_buffer" -> "1", + "dataLayout" -> "NCHW", + "labelLayout" -> "N" ) val imgRecIter = IO.ImageRecordIter(params) val nBatch = 500 @@ -145,7 +149,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10" + "seed" -> "10", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistIter = IO.MNISTIter(params) @@ -182,7 +188,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10" + "seed" -> "10", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistPack1 = IO.MNISTPack(params) @@ -243,7 +251,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val batchLabel = NDArray.ones(Shape(Array(128, 1))) // test pad - val dataIter0 = new NDArrayIter(data, label, 128, false, "pad") + val dataIter0 = new NDArrayIter(data, label, 128, false, "pad", + dataLayout = "NTC", labelLayout = "NT") var batchCount = 0 val nBatch0 = 8 while(dataIter0.hasNext) { @@ -262,6 +271,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { .addData("data0", data(0)).addData("data1", data(1)) .addLabel("label", label(0)) .setBatchSize(128) + .setLayout("NTC", "NT") .setLastBatchHandle("discard").build() val nBatch1 = 7 batchCount = 0 @@ -277,7 +287,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) // test empty label (for prediction) - val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard") + val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard", + dataLayout = "NTC") batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 8234568d7d9f..a73195f77207 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -184,7 +184,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") + IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", + dataLayout = "NCHW", labelLayout = "NCHW") // symbols var x = Symbol.Variable("data") @@ -234,7 +235,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") + IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", + dataLayout = "NCHW", labelLayout = "NCHW") // symbols var x = Symbol.Variable("data") diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 8decdfe8b8d4..3cf46045a538 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -67,7 +67,8 @@ object ExampleMultiTask { new DataBatch(batch.data, IndexedSeq(label, label), batch.index, - batch.pad, dtype = batch.dtype, layout = batch.layout) + batch.pad, dtype = batch.dtype, dataLayout = batch.dataLayout, + labelLayout = batch.labelLayout) } else { throw new NoSuchElementException } @@ -128,7 +129,7 @@ object ExampleMultiTask { override def getDType(): DType = this.dataIter.getDType() - override def getLayout(): String = this.dataIter.getLayout() + override def getLayout(): (String, String) = this.dataIter.getLayout() // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this.dataIter.provideData diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index e37a3265d322..21cc146a1b32 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -97,7 +97,9 @@ object BucketIo { path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int], _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, - readContent: ReadContent = defaultReadContent, layout: String = "NT", + readContent: ReadContent = defaultReadContent, + dataLayout: String = "NT", + labelLayout: String = "N", dtype : DType = DType.Float32) extends DataIter { private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter]) @@ -173,12 +175,12 @@ object BucketIo { private val _provideDataDesc = { val tmp = IndexedSeq(new DataDesc("data", - Shape(_batchSize, _defaultBucketKey), dtype, layout)) - tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dtype, layout)) + Shape(_batchSize, _defaultBucketKey), dtype, dataLayout)) + tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dtype, dataLayout)) } private val _provideLabelDesc = IndexedSeq(new DataDesc("softmax_label", - Shape(_batchSize, _defaultBucketKey), dtype, layout)) + Shape(_batchSize, _defaultBucketKey), dtype, labelLayout)) private var iBucket = 0 @@ -210,7 +212,8 @@ object BucketIo { getIndex(), getPad(), this.buckets(bucketIdx).asInstanceOf[AnyRef], - batchProvideData, batchProvideLabel, dtype, layout) + batchProvideData, batchProvideLabel, getDType(), + getLayout()._1, getLayout()._2) } /** @@ -250,7 +253,7 @@ object BucketIo { override def getDType(): DType = dtype - override def getLayout(): String = layout + override def getLayout(): (String, String) = (dataLayout, labelLayout) // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = this._provideLabel diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index db491b497b93..7fda57a4e3e7 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -35,7 +35,8 @@ class LabeledPointIter private[mxnet]( private val dataName: String = "data", private val labelName: String = "label", private val dtype: DType = DType.Float32, - private val layout: String = "NCHW") extends DataIter { + private val dataLayout: String = "NCHW", + private val labelLayout: String = "N") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -76,7 +77,7 @@ class LabeledPointIter private[mxnet]( val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - layout, dtype) + dataLayout, labelLayout, dtype) cache += dataBatch dataBatch } @@ -129,11 +130,11 @@ class LabeledPointIter private[mxnet]( } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dtype, layout)) + IndexedSeq(new DataDesc(dataName, dataShape, dtype, dataLayout)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, layout)) + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, labelLayout)) } /** @@ -145,7 +146,7 @@ class LabeledPointIter private[mxnet]( override def getDType(): DType = dtype - override def getLayout(): String = layout + override def getLayout(): (String, String) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index d062a81b9bcc..acbcbc7c2b84 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -29,9 +29,10 @@ class LongLivingDataBatch( override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], override val pad: Int, - override val layout: String, + override val dataLayout: String, + override val labelLayout: String, override val dtype: DType) extends DataBatch(data, label, index, pad, - layout = layout, dtype = dtype) { + dataLayout = dataLayout, labelLayout = labelLayout, dtype = dtype) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index c0d898c72fc7..d239e5c641fa 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -35,7 +35,8 @@ class PointIter private[mxnet]( private val dataName: String = "data", private val labelName: String = "label", private val dtype: DType = DType.Float32, - private val layout: String = "NCHW") extends DataIter { + private val dataLayout: String = "NCHW", + private val labelLayout: String = "N") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -75,7 +76,7 @@ class PointIter private[mxnet]( val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - layout, dtype) + dataLayout, labelLayout, dtype) cache += dataBatch dataBatch } @@ -128,11 +129,11 @@ class PointIter private[mxnet]( } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dtype, layout)) + IndexedSeq(new DataDesc(dataName, dataShape, dtype, dataLayout)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, layout)) + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, labelLayout)) } /** @@ -144,7 +145,7 @@ class PointIter private[mxnet]( override def getDType(): DType = dtype - override def getLayout(): String = layout + override def getLayout(): (String, String) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize From b4fada3c32d5fab06aeaeb3976f46b8ff7bb4aeb Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 24 Jul 2018 14:27:59 -0700 Subject: [PATCH 04/25] add depreciate and example changes --- .../core/src/main/scala/org/apache/mxnet/IO.scala | 2 ++ .../mxnetexamples/imclassification/TrainMnist.scala | 8 ++++++-- .../org/apache/mxnetexamples/multitask/Data.scala | 11 ++++++----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 57eca2140664..c3d0b43cd77c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -354,9 +354,11 @@ abstract class DataIter extends Iterator[DataBatch] { def getIndex(): IndexedSeq[Long] // The name and shape of data provided by this iterator + @deprecated def provideData: ListMap[String, Shape] // The name and shape of label provided by this iterator + @deprecated def provideLabel: ListMap[String, Shape] // Provide type:DataDesc of the data diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala index bd0ce45ffe5f..4fce7235ed7f 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala @@ -77,7 +77,9 @@ object TrainMnist { "shuffle" -> "True", "flat" -> flat, "num_parts" -> kv.numWorkers.toString, - "part_index" -> kv.`rank`.toString)) + "part_index" -> kv.`rank`.toString, + "dataLayout" -> "NT", + "labelLayout" -> "N")) val eval = IO.MNISTIter(Map( "image" -> (dataDir + "t10k-images-idx3-ubyte"), @@ -87,7 +89,9 @@ object TrainMnist { "batch_size" -> batchSize.toString, "flat" -> flat, "num_parts" -> kv.numWorkers.toString, - "part_index" -> kv.`rank`.toString)) + "part_index" -> kv.`rank`.toString, + "dataLayout" -> "NT", + "labelLayout" -> "N")) (train, eval) } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala index bb17046b8b2b..2b0a20b40e76 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala @@ -21,9 +21,6 @@ import org.apache.mxnet.Shape import org.apache.mxnet.IO import org.apache.mxnet.DataIter -/** - * @author Depeng Liang - */ object Data { // return train and val iterators for mnist @@ -35,7 +32,9 @@ object Data { "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", "shuffle" -> "True", - "flat" -> flat + "flat" -> flat, + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val trainDataIter = IO.MNISTIter(trainParams) val testParams = Map( @@ -43,7 +42,9 @@ object Data { "label" -> s"$dataPath/t10k-labels-idx1-ubyte", "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", - "flat" -> flat + "flat" -> flat, + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val testDataIter = IO.MNISTIter(testParams) (trainDataIter, testDataIter) From a2b67135f6a2c6c92a436040b33df1cb93277529 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 24 Jul 2018 14:38:53 -0700 Subject: [PATCH 05/25] Gan and Customop fixes --- .../scala/org/apache/mxnetexamples/customop/Data.scala | 9 +++++++-- .../scala/org/apache/mxnetexamples/gan/GanMnist.scala | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala index d61269c131ff..230c56e38678 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala @@ -20,6 +20,7 @@ package org.apache.mxnetexamples.customop import org.apache.mxnet.{DataIter, IO, Shape} object Data { + // return train and val iterators for mnist def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = { val flat = if (inputShape.length == 3) "False" else "True" @@ -29,7 +30,9 @@ object Data { "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", "shuffle" -> "True", - "flat" -> flat + "flat" -> flat, + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val trainDataIter = IO.MNISTIter(trainParams) val testParams = Map( @@ -37,7 +40,9 @@ object Data { "label" -> s"$dataPath/t10k-labels-idx1-ubyte", "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", - "flat" -> flat + "flat" -> flat, + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val testDataIter = IO.MNISTIter(testParams) (trainDataIter, testDataIter) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala index 70846eebfb8e..5ba93276e629 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala @@ -130,7 +130,9 @@ object GanMnist { "label" -> s"$dataPath/train-labels-idx1-ubyte", "input_shape" -> s"(1, 28, 28)", "batch_size" -> s"$batchSize", - "shuffle" -> "True" + "shuffle" -> "True", + "dataLayout" -> "NT", + "labelLayout" -> "N" ) val mnistIter = IO.MNISTIter(params) From ad3f73cd91b9a365760c26ede3738493796be9eb Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 24 Jul 2018 23:38:43 -0700 Subject: [PATCH 06/25] change the DType --- .../src/main/scala/org/apache/mxnet/IO.scala | 31 +++++++++------ .../org/apache/mxnet/io/MXDataIter.scala | 38 ++++++++----------- .../org/apache/mxnet/io/NDArrayIter.scala | 34 ++++++++++------- .../org/apache/mxnet/io/PrefetchingIter.scala | 7 ++-- .../org/apache/mxnet/io/ResizeIter.scala | 4 +- .../apache/mxnetexamples/gan/GanMnist.scala | 2 +- .../multitask/ExampleMultiTask.scala | 6 +-- .../apache/mxnetexamples/rnn/BucketIo.scala | 14 ++++--- .../mxnet/spark/io/LabeledPointIter.scala | 11 +++--- .../mxnet/spark/io/LongLivingDataBatch.scala | 6 ++- .../org/apache/mxnet/spark/io/PointIter.scala | 11 +++--- 11 files changed, 89 insertions(+), 75 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index c3d0b43cd77c..6d90ece4be18 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -25,7 +25,7 @@ import org.slf4j.LoggerFactory import scala.annotation.varargs import scala.collection.immutable.ListMap import scala.collection.mutable.ListBuffer - +import scala.reflect.runtime.universe._ /** * IO iterators for loading training & validation data */ @@ -108,8 +108,12 @@ object IO { val labelName = params.getOrElse("label_name", "label") val dataLayout = params.getOrElse("dataLayout", "NCHW") val labelLayout = params.getOrElse("labelLayout", "N") + val dataDType = params.getOrElse("dataDType", "Float32") + val labelDType = params.getOrElse("labelDType", "Int32") new MXDataIter(out.value, dataName, labelName, - dataLayout = dataLayout, labelLayout = labelLayout) + dataLayout = dataLayout, labelLayout = labelLayout, + dataDType = q"DType ${TermName(dataDType)}".asInstanceOf[DType], + labelDType = q"DType ${TermName(labelDType)}".asInstanceOf[DType]) } // Convert data into canonical form. @@ -144,7 +148,8 @@ class DataBatch(val data: IndexedSeq[NDArray], // (must match the order of input data/label) private val providedData: ListMap[String, Shape] = null, private val providedLabel: ListMap[String, Shape] = null, - val dtype: DType = Base.MX_REAL_TYPE, + val dataDType: DType = Base.MX_REAL_TYPE, + val labelDType: DType = DType.Int32, val dataLayout: String = "NCHW", val labelLayout: String = "N") { /** @@ -178,7 +183,8 @@ object DataBatch { private var pad: Int = 0 private var dataLayout: String = "NCHW" private var labelLayout: String = "N" - private var dtype: DType = Base.MX_REAL_TYPE + private var dataDType: DType = Base.MX_REAL_TYPE + private var labelDType: DType = DType.Int32 private var bucketKey: AnyRef = null private var datatShapes: ListMap[String, Shape] = null private var labelShapes: ListMap[String, Shape] = null @@ -227,11 +233,13 @@ object DataBatch { /** * Set the dtype. - * @param dtype The dtype of the label, default is Float32 + * @param dataDType The dtype of the data, default is Float32 + * @param labelDType The dtype of the label, default is Int32 * @return this */ - def setDType(dtype: DType): Builder = { - this.dtype = dtype + def setDType(dataDType: DType, labelDType: DType): Builder = { + this.dataDType = dataDType + this.labelDType = labelDType this } @@ -290,7 +298,7 @@ object DataBatch { def build(): DataBatch = { require(data != null, "data is required.") new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, - dtype, dataLayout, labelLayout) + dataDType, labelDType, dataLayout, labelLayout) } } } @@ -313,7 +321,8 @@ abstract class DataIter extends Iterator[DataBatch] { @throws(classOf[NoSuchElementException]) def next(): DataBatch = { new DataBatch(getData(), getLabel(), getIndex(), getPad(), - dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2) + dataDType = getDType()._1, labelDType = getDType()._2, + dataLayout = getLayout()._1, labelLayout = getLayout()._2) } /** @@ -337,9 +346,9 @@ abstract class DataIter extends Iterator[DataBatch] { /** * Get the DType - * @return DType of the DataIter + * @return data and label DType of the DataIter */ - def getDType(): DType + def getDType(): (DType, DType) /** * Get the layout diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 5d8fd6d07512..292c8c49b965 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -35,7 +35,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, labelName: String = "label", dtype: DType = DType.Float32, dataLayout: String = "NCHW", - labelLayout: String = "N") + labelLayout: String = "N", + dataDType: DType = DType.Float32, + labelDType: DType = DType.Int32) extends DataIter with WarnIfNotDisposed { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) @@ -45,40 +47,30 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, // fix me if any better way found) private var currentBatch: DataBatch = null - private val (_provideData: ListMap[String, Shape], - _provideLabel: ListMap[String, Shape]) = - if (hasNext) { - iterNext() - val data = currentBatch.data(0) - val label = currentBatch.label(0) - // properties - val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape)) - currentBatch.dispose() - reset() - res - } else { - (null, null) - } - private val (_provideDataDesc: IndexedSeq[DataDesc], _provideLabelDesc: IndexedSeq[DataDesc], + _provideData: ListMap[String, Shape], + _provideLabel: ListMap[String, Shape], _batchSize: Int) = { if (hasNext) { iterNext() val data = currentBatch.data(0) val label = currentBatch.label(0) - val dType = currentBatch.dtype + val dataType = currentBatch.dataDType + val labelDType = currentBatch.labelDType val dataLayout = currentBatch.dataLayout val labelLayout = currentBatch.labelLayout // properties - val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, dataLayout)), - IndexedSeq(new DataDesc(labelName, label.shape, dType, labelLayout)), + val res = (IndexedSeq(new DataDesc(dataName, data.shape, dataDType, dataLayout)), + IndexedSeq(new DataDesc(labelName, label.shape, labelDType, labelLayout)), + ListMap(dataName -> data.shape), + ListMap(labelName -> label.shape), data.shape(0)) currentBatch.dispose() reset() res } else { - (null, null, 0) + (null, null, null, null, 0) } } @@ -130,8 +122,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), index = getIndex(), pad = getPad(), - dtype = getDType(), dataLayout = getLayout()._1, - labelLayout = getLayout()._2) + dataDType = getDType()._1, labelDType = getDType()._2, + dataLayout = getLayout()._1, labelLayout = getLayout()._2) } else { currentBatch = null } @@ -184,7 +176,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, * Get the DType * @return DType */ - def getDType(): DType = dtype + def getDType(): (DType, DType) = (dataDType, labelDType) /** * Get the layout diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 53f7c352c636..4330ce36c130 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -44,7 +44,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], label: IndexedSeq[(String, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, lastBatchHandle: String, - dtype: DType, dataLayout: String, labelLayout: String) extends DataIter { + dataDType: DType, labelDType: DType, + dataLayout: String, labelLayout: String) extends DataIter { /** * @param data Specify the data. Data names will be data_0, data_1, ..., etc. @@ -62,11 +63,11 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label", - dType: DType = MX_REAL_TYPE, dataLayout: String = "NCHW", - labelLayout: String = "N") { + dataDType: DType = MX_REAL_TYPE, labelDType: DType = DType.Int32, + dataLayout: String = "NCHW", labelLayout: String = "N") { this(IO.initData(data, allowEmpty = false, dataName), IO.initData(label, allowEmpty = true, labelName), - dataBatchSize, shuffle, lastBatchHandle, dType, dataLayout, labelLayout) + dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout) } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) @@ -113,8 +114,9 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private val (_provideDataDesc: IndexedSeq[DataDesc], _provideLabelDesc: IndexedSeq[DataDesc]) = { - val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, dataLayout)) - val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, labelLayout)) + val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dataDType, dataLayout)) + val pLabel = initLabel.map(ele => + new DataDesc(ele._1, getShape(ele)._2, labelDType, labelLayout)) (pData, pLabel) } @@ -160,7 +162,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], if (hasNext) { cursor += dataBatchSize new DataBatch(getData(), getLabel(), getIndex(), getPad(), - dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2) + dataDType = getDType()._1, labelDType = getDType()._2, + dataLayout = getLayout()._1, labelLayout = getLayout()._2) } else { throw new NoSuchElementException } @@ -239,8 +242,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], * Get the DType * @return DType */ - def getDType(): DType = { - dtype + def getDType(): (DType, DType) = { + (dataDType, labelDType) } /** @@ -278,7 +281,8 @@ object NDArrayIter { private var lastBatchHandle: String = "pad" private var dataLayout: String = "NCHW" private var labelLayout: String = "N" - private var dtype: DType = Base.MX_REAL_TYPE + private var dataDType: DType = Base.MX_REAL_TYPE + private var labelDType: DType = DType.Int32 /** * Add one data input with its name. @@ -324,11 +328,13 @@ object NDArrayIter { /** * Set the dtype. - * @param dtype The dtype of the label, default is Float32 + * @param dataDType The dtype of the data, default is Float32 + * @param labelDType The dtype of the label, default is Int32 * @return this */ - def setDType(dtype: DType): Builder = { - this.dtype = dtype + def setDType(dataDType: DType, labelDType: DType): Builder = { + this.dataDType = dataDType + this.labelDType = labelDType this } @@ -350,7 +356,7 @@ object NDArrayIter { */ def build(): NDArrayIter = { new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, - dtype, dataLayout, labelLayout) + dataDType, labelDType, dataLayout, labelLayout) } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index 3097948b5de9..2658476402ab 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -181,8 +181,8 @@ class PrefetchingIter( * Get the DType * @return DType */ - def getDType(): DType = { - currentBatch.dtype + def getDType(): (DType, DType) = { + (currentBatch.dataDType, currentBatch.labelDType) } /** @@ -226,7 +226,8 @@ class PrefetchingIter( nextBatch(0).pad, dataLayout = nextBatch(0).dataLayout, labelLayout = nextBatch(0).labelLayout, - dtype = nextBatch(0).dtype) + dataDType = nextBatch(0).dataDType, + labelDType = nextBatch(0).labelDType) for (e <- dataTaken) e.release() true } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index 6e521c219128..f316709a404c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -133,8 +133,8 @@ class ResizeIter( * Get the DType * @return DType */ - def getDType(): DType = { - currentBatch.dtype + def getDType(): (DType, DType) = { + (currentBatch.dataDType, currentBatch.labelDType) } /** diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala index 5ba93276e629..f145c189148e 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala @@ -131,7 +131,7 @@ object GanMnist { "input_shape" -> s"(1, 28, 28)", "batch_size" -> s"$batchSize", "shuffle" -> "True", - "dataLayout" -> "NT", + "dataLayout" -> "NCHW", "labelLayout" -> "N" ) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 3cf46045a538..0e865ab3c0ec 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -67,8 +67,8 @@ object ExampleMultiTask { new DataBatch(batch.data, IndexedSeq(label, label), batch.index, - batch.pad, dtype = batch.dtype, dataLayout = batch.dataLayout, - labelLayout = batch.labelLayout) + batch.pad, dataDType = batch.dataDType, labelDType = batch.labelDType, + dataLayout = batch.dataLayout, labelLayout = batch.labelLayout) } else { throw new NoSuchElementException } @@ -127,7 +127,7 @@ object ExampleMultiTask { */ override def getPad(): Int = this.dataIter.getPad() - override def getDType(): DType = this.dataIter.getDType() + override def getDType(): (DType, DType) = this.dataIter.getDType() override def getLayout(): (String, String) = this.dataIter.getLayout() diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index 21cc146a1b32..22688c1db006 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -100,7 +100,8 @@ object BucketIo { readContent: ReadContent = defaultReadContent, dataLayout: String = "NT", labelLayout: String = "N", - dtype : DType = DType.Float32) extends DataIter { + dataDType : DType = DType.Float32, + labelDType: DType = DType.Int32) extends DataIter { private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter]) @@ -175,12 +176,12 @@ object BucketIo { private val _provideDataDesc = { val tmp = IndexedSeq(new DataDesc("data", - Shape(_batchSize, _defaultBucketKey), dtype, dataLayout)) - tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dtype, dataLayout)) + Shape(_batchSize, _defaultBucketKey), dataDType, dataLayout)) + tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dataDType, dataLayout)) } private val _provideLabelDesc = IndexedSeq(new DataDesc("softmax_label", - Shape(_batchSize, _defaultBucketKey), dtype, labelLayout)) + Shape(_batchSize, _defaultBucketKey), labelDType, labelLayout)) private var iBucket = 0 @@ -212,7 +213,8 @@ object BucketIo { getIndex(), getPad(), this.buckets(bucketIdx).asInstanceOf[AnyRef], - batchProvideData, batchProvideLabel, getDType(), + batchProvideData, batchProvideLabel, + getDType()._1, getDType()._2, getLayout()._1, getLayout()._2) } @@ -251,7 +253,7 @@ object BucketIo { */ override def getPad(): Int = 0 - override def getDType(): DType = dtype + override def getDType(): (DType, DType) = (dataDType, labelDType) override def getLayout(): (String, String) = (dataLayout, labelLayout) diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index 7fda57a4e3e7..fbb8874477f2 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -34,7 +34,8 @@ class LabeledPointIter private[mxnet]( private val _batchSize: Int, private val dataName: String = "data", private val labelName: String = "label", - private val dtype: DType = DType.Float32, + private val dataDType: DType = DType.Float32, + private val labelDType: DType = DType.Int32, private val dataLayout: String = "NCHW", private val labelLayout: String = "N") extends DataIter { @@ -77,7 +78,7 @@ class LabeledPointIter private[mxnet]( val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - dataLayout, labelLayout, dtype) + dataLayout, labelLayout, dataDType, labelDType) cache += dataBatch dataBatch } @@ -130,11 +131,11 @@ class LabeledPointIter private[mxnet]( } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dtype, dataLayout)) + IndexedSeq(new DataDesc(dataName, dataShape, dataDType, dataLayout)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, labelLayout)) + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), labelDType, labelLayout)) } /** @@ -144,7 +145,7 @@ class LabeledPointIter private[mxnet]( */ override def getPad(): Int = 0 - override def getDType(): DType = dtype + override def getDType(): (DType, DType) = (dataDType, labelDType) override def getLayout(): (String, String) = (dataLayout, labelLayout) diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index acbcbc7c2b84..62bec0a5a5d6 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -31,8 +31,10 @@ class LongLivingDataBatch( override val pad: Int, override val dataLayout: String, override val labelLayout: String, - override val dtype: DType) extends DataBatch(data, label, index, pad, - dataLayout = dataLayout, labelLayout = labelLayout, dtype = dtype) { + override val dataDType: DType, + override val labelDType: DType) extends DataBatch(data, label, index, pad, + dataLayout = dataLayout, labelLayout = labelLayout, + dataDType = dataDType, labelDType = labelDType) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index d239e5c641fa..a43906d5e365 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -34,7 +34,8 @@ class PointIter private[mxnet]( private val _batchSize: Int, private val dataName: String = "data", private val labelName: String = "label", - private val dtype: DType = DType.Float32, + private val dataDType: DType = DType.Float32, + private val labelDType: DType = DType.Int32, private val dataLayout: String = "NCHW", private val labelLayout: String = "N") extends DataIter { @@ -76,7 +77,7 @@ class PointIter private[mxnet]( val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - dataLayout, labelLayout, dtype) + dataLayout, labelLayout, dataDType, labelDType) cache += dataBatch dataBatch } @@ -129,11 +130,11 @@ class PointIter private[mxnet]( } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dtype, dataLayout)) + IndexedSeq(new DataDesc(dataName, dataShape, dataDType, dataLayout)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), dtype, labelLayout)) + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), labelDType, labelLayout)) } /** @@ -143,7 +144,7 @@ class PointIter private[mxnet]( */ override def getPad(): Int = 0 - override def getDType(): DType = dtype + override def getDType(): (DType, DType) = (dataDType, labelDType) override def getLayout(): (String, String) = (dataLayout, labelLayout) From 559ed96bc14534d2c86fb540336bf2893315414f Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 25 Jul 2018 10:32:24 -0700 Subject: [PATCH 07/25] add one more class to convert Strings to DTypes --- .../core/src/main/scala/org/apache/mxnet/DType.scala | 9 +++++++++ .../core/src/main/scala/org/apache/mxnet/IO.scala | 5 ++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala index 4458a7c7aeb8..b015bd2169b7 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala @@ -35,4 +35,13 @@ object DType extends Enumeration { case DType.Unknown => 0 } } + private[mxnet] def getType(dtypeStr: String): DType = { + dtypeStr match { + case "UInt8" => DType.UInt8 + case "Int32" => DType.Int32 + case "Float16" => DType.Float16 + case "Float32" => DType.Float32 + case "Float64" => DType.Float64 + } + } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 6d90ece4be18..56cd59a9c24a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory import scala.annotation.varargs import scala.collection.immutable.ListMap import scala.collection.mutable.ListBuffer -import scala.reflect.runtime.universe._ /** * IO iterators for loading training & validation data */ @@ -112,8 +111,8 @@ object IO { val labelDType = params.getOrElse("labelDType", "Int32") new MXDataIter(out.value, dataName, labelName, dataLayout = dataLayout, labelLayout = labelLayout, - dataDType = q"DType ${TermName(dataDType)}".asInstanceOf[DType], - labelDType = q"DType ${TermName(labelDType)}".asInstanceOf[DType]) + dataDType = DType.getType(dataDType), + labelDType = DType.getType(labelDType)) } // Convert data into canonical form. From 33acc7b3418e64c58b8cae139b5ca4a9e8db909c Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 27 Jul 2018 11:29:40 -0700 Subject: [PATCH 08/25] convert layout to global --- .../src/main/scala/org/apache/mxnet/IO.scala | 23 +++++---- .../main/scala/org/apache/mxnet/Layout.scala | 48 +++++++++++++++++++ .../org/apache/mxnet/io/MXDataIter.scala | 10 ++-- .../org/apache/mxnet/io/NDArrayIter.scala | 46 ++++++++++-------- .../org/apache/mxnet/io/PrefetchingIter.scala | 3 +- .../org/apache/mxnet/io/ResizeIter.scala | 3 +- .../module/DataParallelExecutorGroup.scala | 2 +- .../test/scala/org/apache/mxnet/IOSuite.scala | 6 +-- .../scala/org/apache/mxnet/ModuleSuite.scala | 32 ++++++------- .../ImageClassifierExample.scala | 7 ++- .../objectdetector/SSDClassifierExample.scala | 6 +-- .../multitask/ExampleMultiTask.scala | 4 +- .../apache/mxnetexamples/rnn/BucketIo.scala | 7 +-- .../apache/mxnet/infer/ImageClassifier.scala | 5 +- .../org/apache/mxnet/infer/Predictor.scala | 10 ++-- .../mxnet/infer/ImageClassifierSuite.scala | 9 ++-- .../apache/mxnet/infer/PredictorSuite.scala | 6 +-- .../mxnet/spark/io/LabeledPointIter.scala | 7 +-- .../mxnet/spark/io/LongLivingDataBatch.scala | 5 +- .../org/apache/mxnet/spark/io/PointIter.scala | 7 +-- 20 files changed, 153 insertions(+), 93 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 56cd59a9c24a..928606db8b70 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -19,6 +19,7 @@ package org.apache.mxnet import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import org.apache.mxnet.io.{MXDataIter, MXDataPack} import org.slf4j.LoggerFactory @@ -110,7 +111,8 @@ object IO { val dataDType = params.getOrElse("dataDType", "Float32") val labelDType = params.getOrElse("labelDType", "Int32") new MXDataIter(out.value, dataName, labelName, - dataLayout = dataLayout, labelLayout = labelLayout, + dataLayout = Layout.getLayout(dataLayout), + labelLayout = Layout.getLayout(labelLayout), dataDType = DType.getType(dataDType), labelDType = DType.getType(labelDType)) } @@ -149,8 +151,8 @@ class DataBatch(val data: IndexedSeq[NDArray], private val providedLabel: ListMap[String, Shape] = null, val dataDType: DType = Base.MX_REAL_TYPE, val labelDType: DType = DType.Int32, - val dataLayout: String = "NCHW", - val labelLayout: String = "N") { + val dataLayout: Layout = Layout.NCHW, + val labelLayout: Layout = Layout.N) { /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -180,8 +182,8 @@ object DataBatch { private var label: IndexedSeq[NDArray] = null private var index: IndexedSeq[Long] = null private var pad: Int = 0 - private var dataLayout: String = "NCHW" - private var labelLayout: String = "N" + private var dataLayout: Layout = Layout.NCHW + private var labelLayout: Layout = Layout.N private var dataDType: DType = Base.MX_REAL_TYPE private var labelDType: DType = DType.Int32 private var bucketKey: AnyRef = null @@ -248,7 +250,7 @@ object DataBatch { * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(dataLayout: String, labelLayout: String): Builder = { + def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = { this.dataLayout = dataLayout this.labelLayout = labelLayout this @@ -353,7 +355,7 @@ abstract class DataIter extends Iterator[DataBatch] { * Get the layout * @return data and label layout of the DataIter */ - def getLayout(): (String, String) + def getLayout(): (Layout, Layout) /** * Get the index of current batch @@ -393,10 +395,11 @@ abstract class DataPack() extends Iterable[DataBatch] { // Named data desc description contains name, shape, type and other extended attributes. case class DataDesc(name: String, shape: Shape, - dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") { - require(shape.length == layout.length, ("number of dimensions in shape :%d with" + + dtype: DType = Base.MX_REAL_TYPE, layout: Layout = Layout.NCHW) { + val layoutStr = layout.toString + require(shape.length == layoutStr.length, ("number of dimensions in shape :%d with" + " shape: %s should match the length of the layout: %d with layout: %s"). - format(shape.length, shape.toString, layout.length, layout)) + format(shape.length, shape.toString, layoutStr.length, layoutStr)) override def toString(): String = { s"DataDesc[$name,$shape,$dtype,$layout]" diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala new file mode 100644 index 000000000000..069d2ae03e82 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +/** + * Layout type that represent what inside of a shape + * N Batch size + * C number of channels + * H height (image) + * W width (image) + * T temporal axis representing time (NLP) + */ + +object Layout extends Enumeration { + type Layout = Value + val NCHW = Value("NCHW") + val TNC = Value("TNC") + val CHW = Value("CHW") + val NT = Value("NT") + val N = Value("N") + + private[mxnet] def getLayout(layoutStr: String): Layout = { + layoutStr match { + case "NCHW" => NCHW + case "TNC" => TNC + case "CHW" => CHW + case "NT" => NT + case "N" => N + case _ => throw new RuntimeException( + s"Unknown $layoutStr defined!, please check Layout.scala") + } + } +} \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 292c8c49b965..0bc61241aa64 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -21,6 +21,7 @@ import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType import org.apache.mxnet._ import org.apache.mxnet.IO._ +import org.apache.mxnet.Layout.Layout import org.slf4j.LoggerFactory import scala.collection.immutable.ListMap @@ -33,9 +34,8 @@ import scala.collection.mutable.ListBuffer private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, dataName: String = "data", labelName: String = "label", - dtype: DType = DType.Float32, - dataLayout: String = "NCHW", - labelLayout: String = "N", + dataLayout: Layout = Layout.NCHW, + labelLayout: Layout = Layout.N, dataDType: DType = DType.Float32, labelDType: DType = DType.Int32) extends DataIter with WarnIfNotDisposed { @@ -56,7 +56,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, iterNext() val data = currentBatch.data(0) val label = currentBatch.label(0) - val dataType = currentBatch.dataDType + val dataDType = currentBatch.dataDType val labelDType = currentBatch.labelDType val dataLayout = currentBatch.dataLayout val labelLayout = currentBatch.labelLayout @@ -182,7 +182,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, * Get the layout * @return layout */ - def getLayout(): (String, String) = (dataLayout, labelLayout) + def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 4330ce36c130..e0ac7e6d5265 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -21,6 +21,7 @@ import java.util.NoSuchElementException import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -45,31 +46,34 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, lastBatchHandle: String, dataDType: DType, labelDType: DType, - dataLayout: String, labelLayout: String) extends DataIter { - + dataLayout: Layout, labelLayout: Layout) extends DataIter { +// scalastyle:off /** - * @param data Specify the data. Data names will be data_0, data_1, ..., etc. - * @param label Same as data, but is not fed to the model during testing. - * Label names will be label_0, label_1, ..., etc. - * @param dataBatchSize Batch Size - * @param shuffle Whether to shuffle the data - * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch - * - * This iterator will pad, discard or roll over the last batch if - * the size of data does not match batch_size. Roll over is intended - * for training and can cause problems if used for prediction. - */ + * @param data Specify the data. Data names will be data_0, data_1, ..., etc. + * @param label Same as data, but is not fed to the model during testing. + * Label names will be label_0, label_1, ..., etc. + * @param dataBatchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label", - dataDType: DType = MX_REAL_TYPE, labelDType: DType = DType.Int32, - dataLayout: String = "NCHW", labelLayout: String = "N") { + dataDType: DType = Base.MX_REAL_TYPE, + labelDType: DType = DType.Int32, + dataLayout: Layout = Layout.NCHW, + labelLayout : Layout = Layout.N) { this(IO.initData(data, allowEmpty = false, dataName), IO.initData(label, allowEmpty = true, labelName), - dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout) + dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, + dataLayout, labelLayout) } - +// scalastyle:on private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = { @@ -250,7 +254,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], * Get the layout * @return layout */ - def getLayout(): (String, String) = { + def getLayout(): (Layout, Layout) = { (dataLayout, labelLayout) } @@ -279,8 +283,8 @@ object NDArrayIter { private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" - private var dataLayout: String = "NCHW" - private var labelLayout: String = "N" + private var dataLayout: Layout = Layout.NCHW + private var labelLayout: Layout = Layout.N private var dataDType: DType = Base.MX_REAL_TYPE private var labelDType: DType = DType.Int32 @@ -344,7 +348,7 @@ object NDArrayIter { * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(dataLayout: String, labelLayout: String): Builder = { + def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = { this.dataLayout = dataLayout this.labelLayout = labelLayout this diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index 2658476402ab..bcfb1d043272 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -22,6 +22,7 @@ import org.slf4j.LoggerFactory import java.util.concurrent.Semaphore import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import scala.collection.immutable.ListMap @@ -189,7 +190,7 @@ class PrefetchingIter( * Get the layout * @return layout */ - def getLayout(): (String, String) = { + def getLayout(): (Layout, Layout) = { (currentBatch.dataLayout, currentBatch.labelLayout) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index f316709a404c..5de42290154a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -20,6 +20,7 @@ package org.apache.mxnet.io import java.util.NoSuchElementException import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -141,7 +142,7 @@ class ResizeIter( * Get the layout * @return layout */ - def getLayout(): (String, String) = { + def getLayout(): (Layout, Layout) = { (currentBatch.dataLayout, currentBatch.labelLayout) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala index 1494dc84035c..bb497e47d139 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala @@ -330,7 +330,7 @@ class DataParallelExecutorGroup private[module]( */ private def decideSlices(dataShapes: Seq[DataDesc]): Seq[Int] = { require(dataShapes.size > 0) - val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layout))) + val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layoutStr))) for ((dataDesc, axis) <- dataShapes.zip(majorAxis)) { if (axis != -1) { diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 478f834d210b..919c94b4b819 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -252,7 +252,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test pad val dataIter0 = new NDArrayIter(data, label, 128, false, "pad", - dataLayout = "NTC", labelLayout = "NT") + dataLayout = Layout.TNC, labelLayout = Layout.NT) var batchCount = 0 val nBatch0 = 8 while(dataIter0.hasNext) { @@ -271,7 +271,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { .addData("data0", data(0)).addData("data1", data(1)) .addLabel("label", label(0)) .setBatchSize(128) - .setLayout("NTC", "NT") + .setLayout(Layout.TNC, Layout.NT) .setLastBatchHandle("discard").build() val nBatch1 = 7 batchCount = 0 @@ -288,7 +288,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test empty label (for prediction) val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard", - dataLayout = "NTC") + dataLayout = Layout.TNC) batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index a73195f77207..10c547c5a9b2 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -33,7 +33,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, "TNC"))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, Layout.TNC))) mod.initParams() mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape, dtype = dType)), @@ -57,9 +57,9 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { .setContext(Context.cpu(0), Context.cpu(1)) .build() mod.bind(dataShapes = IndexedSeq( - DataDesc("b", Shape(5, 5), layout = "NT"), - DataDesc("c", Shape(5, 5), layout = "NT"), - DataDesc("a", Shape(5, 5), layout = "NT")), + DataDesc("b", Shape(5, 5), layout = Layout.NT), + DataDesc("c", Shape(5, 5), layout = Layout.NT), + DataDesc("a", Shape(5, 5), layout = Layout.NT)), inputsNeedGrad = true ) mod.initParams() @@ -87,7 +87,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val dShape = Shape(3, 8, 7) val mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "TNC"))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = Layout.TNC))) mod.initParams() mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape)), @@ -110,14 +110,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // single device var mod = new Module(sym, IndexedSeq("data"), null) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) mod.update() mod.saveCheckpoint("test", 0, saveOptStates = true) var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true) - mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) + mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) assert(mod.getSymbol.toJson == mod2.getSymbol.toJson) mapEqu(mod.getParams._1, mod2.getParams._1) @@ -125,14 +125,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // multi device mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT" ))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) mod.update() mod.saveCheckpoint("test", 0, saveOptStates = true) mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true) - mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) + mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) assert(mod.getSymbol.toJson == mod2.getSymbol.toJson) mapEqu(mod.getParams._1, mod2.getParams._1) @@ -145,7 +145,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { var dShape = Shape(7, 20) val mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "NT"))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = Layout.NT))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 1f)) @@ -159,7 +159,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // reshape module dShape = Shape(14, 20) - mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT"))) + mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = Layout.NT))) mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape)), label = null, index = null, pad = 0)) @@ -170,7 +170,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // return to original binded shape dShape = Shape(7, 20) - mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT"))) + mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = Layout.NT))) mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape)), label = null, index = null, pad = 0)) @@ -185,7 +185,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", - dataLayout = "NCHW", labelLayout = "NCHW") + dataLayout = Layout.NCHW, labelLayout = Layout.NCHW) // symbols var x = Symbol.Variable("data") @@ -236,7 +236,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", - dataLayout = "NCHW", labelLayout = "NCHW") + dataLayout = Layout.NCHW, labelLayout = Layout.NCHW) // symbols var x = Symbol.Variable("data") @@ -311,8 +311,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val mod = new Module(sym, IndexedSeq("data1", "data2")) mod.bind(dataShapes = IndexedSeq( - DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = "NCHW")), - labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = "N"))) + DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = Layout.NCHW)), + labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = Layout.N))) ) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f)) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala index 3bbd780d39b9..3bae4a703e4c 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala @@ -17,10 +17,9 @@ package org.apache.mxnetexamples.infer.imageclassifier -import org.apache.mxnet.Shape +import org.apache.mxnet._ import org.kohsuke.args4j.{CmdLineParser, Option} import org.slf4j.LoggerFactory -import org.apache.mxnet.{DType, DataDesc, Context} import org.apache.mxnet.infer.ImageClassifier import scala.collection.JavaConverters._ @@ -46,7 +45,7 @@ object ImageClassifierExample { val dType = DType.Float32 val inputShape = Shape(1, 3, 224, 224) - val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) + val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) // Create object of ImageClassifier class val imgClassifier: ImageClassifier = new @@ -67,7 +66,7 @@ object ImageClassifierExample { val dType = DType.Float32 val inputShape = Shape(1, 3, 224, 224) - val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) + val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) // Create object of ImageClassifier class val imgClassifier: ImageClassifier = new diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala index c9707cb3ff6f..2b8e49b8042f 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala @@ -19,7 +19,7 @@ package org.apache.mxnetexamples.infer.objectdetector import java.io.File -import org.apache.mxnet.{Context, DType, DataDesc, Shape} +import org.apache.mxnet._ import org.apache.mxnet.infer._ import org.kohsuke.args4j.{CmdLineParser, Option} import org.slf4j.LoggerFactory @@ -58,7 +58,7 @@ object SSDClassifierExample { val inputShape = Shape(1, 3, 512, 512) // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...]) val outputShape = Shape(1, 6132, 6) - val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) + val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) val img = ImageClassifier.loadImageFromFile(inputImagePath) val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context) val output = objDetector.imageObjectDetect(img, Some(3)) @@ -73,7 +73,7 @@ object SSDClassifierExample { val inputShape = Shape(1, 3, 512, 512) // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...]) val outputShape = Shape(1, 6132, 6) - val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) + val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context) // Loading batch of images from the directory path val batchFiles = generateBatches(inputImageDir, 20) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 0e865ab3c0ec..6410d7781295 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -25,12 +25,12 @@ import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ import org.apache.commons.io.FileUtils - import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, NDArray, Shape, Symbol, Xavier} import org.apache.mxnet.DType.DType import org.apache.mxnet.optimizer.RMSProp import org.apache.mxnet.Executor import org.apache.mxnetexamples.Util +import org.apache.mxnet.Layout.Layout import scala.collection.immutable.ListMap import scala.sys.process.Process @@ -129,7 +129,7 @@ object ExampleMultiTask { override def getDType(): (DType, DType) = this.dataIter.getDType() - override def getLayout(): (String, String) = this.dataIter.getLayout() + override def getLayout(): (Layout, Layout) = this.dataIter.getLayout() // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this.dataIter.provideData diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index 22688c1db006..2bf9654802e6 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -19,6 +19,7 @@ package org.apache.mxnetexamples.rnn import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -98,8 +99,8 @@ object BucketIo { _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, readContent: ReadContent = defaultReadContent, - dataLayout: String = "NT", - labelLayout: String = "N", + dataLayout: Layout = Layout.NT, + labelLayout: Layout = Layout.N, dataDType : DType = DType.Float32, labelDType: DType = DType.Int32) extends DataIter { @@ -255,7 +256,7 @@ object BucketIo { override def getDType(): (DType, DType) = (dataDType, labelDType) - override def getLayout(): (String, String) = (dataLayout, labelLayout) + override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = this._provideLabel diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala index 8d31d1f6b3d6..db5923efe5c5 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala @@ -49,10 +49,11 @@ class ImageClassifier(modelPathPrefix: String, extends Classifier(modelPathPrefix, inputDescriptors, contexts, epoch) { - protected[infer] val inputLayout = inputDescriptors.head.layout + protected[infer] val inputLayout = inputDescriptors.head.layout.toString require(inputDescriptors.nonEmpty, "Please provide input descriptor") - require(inputDescriptors.head.layout == "NCHW", "Provided layout doesn't match NCHW format") + require(inputDescriptors.head.layout.toString == "NCHW", + "Provided layout doesn't match NCHW format") protected[infer] val inputShape = inputDescriptors.head.shape diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala index 2a4f03056372..75b55209b1d2 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala @@ -18,7 +18,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter -import org.apache.mxnet.{Context, DataDesc, NDArray, Shape} +import org.apache.mxnet._ import org.apache.mxnet.module.Module import scala.collection.mutable.ListBuffer @@ -76,15 +76,15 @@ class Predictor(modelPathPrefix: String, private val logger = LoggerFactory.getLogger(classOf[Predictor]) - require(inputDescriptors.head.layout.size != 0, "layout size should not be zero") + require(inputDescriptors.head.layout.toString.size != 0, "layout size should not be zero") - protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N') + protected[infer] var batchIndex = inputDescriptors(0).layout.toString.indexOf('N') protected[infer] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex) else 1 protected[infer] var iDescriptors = inputDescriptors - inputDescriptors.foreach((f: DataDesc) => require(f.layout.indexOf('N') == batchIndex, + inputDescriptors.foreach((f: DataDesc) => require(f.layout.toString.indexOf('N') == batchIndex, "batch size should be in the same index for all inputs")) if (batchIndex != -1) { @@ -94,7 +94,7 @@ class Predictor(modelPathPrefix: String, // Note: this is assuming that the input needs a batch logger.warn("InputDescriptor does not have batchSize, using 1 as the default batchSize") iDescriptors = inputDescriptors.map((f: DataDesc) => new DataDesc(f.name, - Shape(1 +: f.shape.toVector), f.dtype, 'N' +: f.layout)) + Shape(1 +: f.shape.toVector), f.dtype, Layout.getLayout('N' +: f.layout.toString))) batchIndex = 1 } diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala index 948764ee8044..1eb3bd91e37e 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.mxnet.infer -import org.apache.mxnet.{DType, DataDesc, Shape, NDArray, Context} - +import org.apache.mxnet._ import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.BeforeAndAfterAll @@ -60,7 +59,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { test("ImageClassifierSuite-testConvertBufferedImageToNDArray") { val dType = DType.Float32 val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 2, 2), - dType, "NCHW")) + dType, Layout.NCHW)) val image1 = new BufferedImage(100, 200, BufferedImage.TYPE_BYTE_GRAY) val image2 = ImageClassifier.reshapeImage(image1, 2, 2) @@ -73,7 +72,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { test("ImageClassifierSuite-testWithInputImage") { val dType = DType.Float32 val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), - dType, "NCHW")) + dType, Layout.NCHW)) val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB) @@ -111,7 +110,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { test("ImageClassifierSuite-testWithInputBatchImage") { val dType = DType.Float32 val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), - dType, "NCHW")) + dType, Layout.NCHW)) val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB) val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage) diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala index 53fd7f310689..cdd5146d0a03 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala @@ -19,7 +19,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter import org.apache.mxnet.module.{BaseModule, Module} -import org.apache.mxnet.{DataDesc, NDArray, Shape} +import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape} import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -45,7 +45,7 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { val mockPredictor = new MyPredictor("xyz", inputDescriptor) assert(mockPredictor.getBatchSize == 1) - assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N')) + assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.toString.indexOf('N')) val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)), new DataDesc("data", Shape(2, 3, 2, 2))) @@ -55,7 +55,7 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { } // batchsize is defaulted to 1 - val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW")) + val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = Layout.CHW)) val p2 = new MyPredictor("xyz", inputDescriptor) assert(p2.getBatchSize == 1, "should use a default batch size of 1") diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index fbb8874477f2..b84c91807336 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -18,6 +18,7 @@ package org.apache.mxnet.spark.io import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.apache.spark.mllib.regression.LabeledPoint @@ -36,8 +37,8 @@ class LabeledPointIter private[mxnet]( private val labelName: String = "label", private val dataDType: DType = DType.Float32, private val labelDType: DType = DType.Int32, - private val dataLayout: String = "NCHW", - private val labelLayout: String = "N") extends DataIter { + private val dataLayout: Layout = Layout.NCHW, + private val labelLayout: Layout = Layout.N) extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -147,7 +148,7 @@ class LabeledPointIter private[mxnet]( override def getDType(): (DType, DType) = (dataDType, labelDType) - override def getLayout(): (String, String) = (dataLayout, labelLayout) + override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index 62bec0a5a5d6..0d5068544adc 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -18,6 +18,7 @@ package org.apache.mxnet.spark.io import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import org.apache.mxnet.{DataBatch, NDArray} /** @@ -29,8 +30,8 @@ class LongLivingDataBatch( override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], override val pad: Int, - override val dataLayout: String, - override val labelLayout: String, + override val dataLayout: Layout, + override val labelLayout: Layout, override val dataDType: DType, override val labelDType: DType) extends DataBatch(data, label, index, pad, dataLayout = dataLayout, labelLayout = labelLayout, diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index a43906d5e365..96eacdab9ab0 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -18,6 +18,7 @@ package org.apache.mxnet.spark.io import org.apache.mxnet.DType.DType +import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.apache.spark.mllib.linalg.Vector @@ -36,8 +37,8 @@ class PointIter private[mxnet]( private val labelName: String = "label", private val dataDType: DType = DType.Float32, private val labelDType: DType = DType.Int32, - private val dataLayout: String = "NCHW", - private val labelLayout: String = "N") extends DataIter { + private val dataLayout: Layout = Layout.NCHW, + private val labelLayout: Layout = Layout.N) extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -146,7 +147,7 @@ class PointIter private[mxnet]( override def getDType(): (DType, DType) = (dataDType, labelDType) - override def getLayout(): (String, String) = (dataLayout, labelLayout) + override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize From ff34f96d305e519b2c3f4056cf0ce0a42665ae63 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 27 Jul 2018 13:02:31 -0700 Subject: [PATCH 09/25] scala style fix --- scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala index 069d2ae03e82..510cf88b59fa 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala @@ -45,4 +45,4 @@ object Layout extends Enumeration { s"Unknown $layoutStr defined!, please check Layout.scala") } } -} \ No newline at end of file +} From 1252d1f1398109c91706eeb066be2bd2a2ea25ab Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 1 Aug 2018 11:40:23 -0700 Subject: [PATCH 10/25] Revert to 8c7d1f8 --- .../src/main/scala/org/apache/mxnet/IO.scala | 23 ++++----- .../main/scala/org/apache/mxnet/Layout.scala | 48 ------------------- .../org/apache/mxnet/io/MXDataIter.scala | 10 ++-- .../org/apache/mxnet/io/NDArrayIter.scala | 46 ++++++++---------- .../org/apache/mxnet/io/PrefetchingIter.scala | 3 +- .../org/apache/mxnet/io/ResizeIter.scala | 3 +- .../module/DataParallelExecutorGroup.scala | 2 +- .../test/scala/org/apache/mxnet/IOSuite.scala | 6 +-- .../scala/org/apache/mxnet/ModuleSuite.scala | 32 ++++++------- .../ImageClassifierExample.scala | 7 +-- .../objectdetector/SSDClassifierExample.scala | 6 +-- .../multitask/ExampleMultiTask.scala | 4 +- .../apache/mxnetexamples/rnn/BucketIo.scala | 7 ++- .../apache/mxnet/infer/ImageClassifier.scala | 5 +- .../org/apache/mxnet/infer/Predictor.scala | 10 ++-- .../mxnet/infer/ImageClassifierSuite.scala | 9 ++-- .../apache/mxnet/infer/PredictorSuite.scala | 6 +-- .../mxnet/spark/io/LabeledPointIter.scala | 7 ++- .../mxnet/spark/io/LongLivingDataBatch.scala | 5 +- .../org/apache/mxnet/spark/io/PointIter.scala | 7 ++- 20 files changed, 93 insertions(+), 153 deletions(-) delete mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 928606db8b70..56cd59a9c24a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -19,7 +19,6 @@ package org.apache.mxnet import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import org.apache.mxnet.io.{MXDataIter, MXDataPack} import org.slf4j.LoggerFactory @@ -111,8 +110,7 @@ object IO { val dataDType = params.getOrElse("dataDType", "Float32") val labelDType = params.getOrElse("labelDType", "Int32") new MXDataIter(out.value, dataName, labelName, - dataLayout = Layout.getLayout(dataLayout), - labelLayout = Layout.getLayout(labelLayout), + dataLayout = dataLayout, labelLayout = labelLayout, dataDType = DType.getType(dataDType), labelDType = DType.getType(labelDType)) } @@ -151,8 +149,8 @@ class DataBatch(val data: IndexedSeq[NDArray], private val providedLabel: ListMap[String, Shape] = null, val dataDType: DType = Base.MX_REAL_TYPE, val labelDType: DType = DType.Int32, - val dataLayout: Layout = Layout.NCHW, - val labelLayout: Layout = Layout.N) { + val dataLayout: String = "NCHW", + val labelLayout: String = "N") { /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -182,8 +180,8 @@ object DataBatch { private var label: IndexedSeq[NDArray] = null private var index: IndexedSeq[Long] = null private var pad: Int = 0 - private var dataLayout: Layout = Layout.NCHW - private var labelLayout: Layout = Layout.N + private var dataLayout: String = "NCHW" + private var labelLayout: String = "N" private var dataDType: DType = Base.MX_REAL_TYPE private var labelDType: DType = DType.Int32 private var bucketKey: AnyRef = null @@ -250,7 +248,7 @@ object DataBatch { * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = { + def setLayout(dataLayout: String, labelLayout: String): Builder = { this.dataLayout = dataLayout this.labelLayout = labelLayout this @@ -355,7 +353,7 @@ abstract class DataIter extends Iterator[DataBatch] { * Get the layout * @return data and label layout of the DataIter */ - def getLayout(): (Layout, Layout) + def getLayout(): (String, String) /** * Get the index of current batch @@ -395,11 +393,10 @@ abstract class DataPack() extends Iterable[DataBatch] { // Named data desc description contains name, shape, type and other extended attributes. case class DataDesc(name: String, shape: Shape, - dtype: DType = Base.MX_REAL_TYPE, layout: Layout = Layout.NCHW) { - val layoutStr = layout.toString - require(shape.length == layoutStr.length, ("number of dimensions in shape :%d with" + + dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") { + require(shape.length == layout.length, ("number of dimensions in shape :%d with" + " shape: %s should match the length of the layout: %d with layout: %s"). - format(shape.length, shape.toString, layoutStr.length, layoutStr)) + format(shape.length, shape.toString, layout.length, layout)) override def toString(): String = { s"DataDesc[$name,$shape,$dtype,$layout]" diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala deleted file mode 100644 index 510cf88b59fa..000000000000 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet - -/** - * Layout type that represent what inside of a shape - * N Batch size - * C number of channels - * H height (image) - * W width (image) - * T temporal axis representing time (NLP) - */ - -object Layout extends Enumeration { - type Layout = Value - val NCHW = Value("NCHW") - val TNC = Value("TNC") - val CHW = Value("CHW") - val NT = Value("NT") - val N = Value("N") - - private[mxnet] def getLayout(layoutStr: String): Layout = { - layoutStr match { - case "NCHW" => NCHW - case "TNC" => TNC - case "CHW" => CHW - case "NT" => NT - case "N" => N - case _ => throw new RuntimeException( - s"Unknown $layoutStr defined!, please check Layout.scala") - } - } -} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 0bc61241aa64..292c8c49b965 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -21,7 +21,6 @@ import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType import org.apache.mxnet._ import org.apache.mxnet.IO._ -import org.apache.mxnet.Layout.Layout import org.slf4j.LoggerFactory import scala.collection.immutable.ListMap @@ -34,8 +33,9 @@ import scala.collection.mutable.ListBuffer private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, dataName: String = "data", labelName: String = "label", - dataLayout: Layout = Layout.NCHW, - labelLayout: Layout = Layout.N, + dtype: DType = DType.Float32, + dataLayout: String = "NCHW", + labelLayout: String = "N", dataDType: DType = DType.Float32, labelDType: DType = DType.Int32) extends DataIter with WarnIfNotDisposed { @@ -56,7 +56,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, iterNext() val data = currentBatch.data(0) val label = currentBatch.label(0) - val dataDType = currentBatch.dataDType + val dataType = currentBatch.dataDType val labelDType = currentBatch.labelDType val dataLayout = currentBatch.dataLayout val labelLayout = currentBatch.labelLayout @@ -182,7 +182,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, * Get the layout * @return layout */ - def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) + def getLayout(): (String, String) = (dataLayout, labelLayout) // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index e0ac7e6d5265..4330ce36c130 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -21,7 +21,6 @@ import java.util.NoSuchElementException import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -46,34 +45,31 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, lastBatchHandle: String, dataDType: DType, labelDType: DType, - dataLayout: Layout, labelLayout: Layout) extends DataIter { -// scalastyle:off + dataLayout: String, labelLayout: String) extends DataIter { + /** - * @param data Specify the data. Data names will be data_0, data_1, ..., etc. - * @param label Same as data, but is not fed to the model during testing. - * Label names will be label_0, label_1, ..., etc. - * @param dataBatchSize Batch Size - * @param shuffle Whether to shuffle the data - * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch - * - * This iterator will pad, discard or roll over the last batch if - * the size of data does not match batch_size. Roll over is intended - * for training and can cause problems if used for prediction. - */ + * @param data Specify the data. Data names will be data_0, data_1, ..., etc. + * @param label Same as data, but is not fed to the model during testing. + * Label names will be label_0, label_1, ..., etc. + * @param dataBatchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label", - dataDType: DType = Base.MX_REAL_TYPE, - labelDType: DType = DType.Int32, - dataLayout: Layout = Layout.NCHW, - labelLayout : Layout = Layout.N) { + dataDType: DType = MX_REAL_TYPE, labelDType: DType = DType.Int32, + dataLayout: String = "NCHW", labelLayout: String = "N") { this(IO.initData(data, allowEmpty = false, dataName), IO.initData(label, allowEmpty = true, labelName), - dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, - dataLayout, labelLayout) + dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout) } -// scalastyle:on + private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = { @@ -254,7 +250,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], * Get the layout * @return layout */ - def getLayout(): (Layout, Layout) = { + def getLayout(): (String, String) = { (dataLayout, labelLayout) } @@ -283,8 +279,8 @@ object NDArrayIter { private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" - private var dataLayout: Layout = Layout.NCHW - private var labelLayout: Layout = Layout.N + private var dataLayout: String = "NCHW" + private var labelLayout: String = "N" private var dataDType: DType = Base.MX_REAL_TYPE private var labelDType: DType = DType.Int32 @@ -348,7 +344,7 @@ object NDArrayIter { * @param labelLayout The layout of the label, default is N * @return this */ - def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = { + def setLayout(dataLayout: String, labelLayout: String): Builder = { this.dataLayout = dataLayout this.labelLayout = labelLayout this diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index bcfb1d043272..2658476402ab 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -22,7 +22,6 @@ import org.slf4j.LoggerFactory import java.util.concurrent.Semaphore import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import scala.collection.immutable.ListMap @@ -190,7 +189,7 @@ class PrefetchingIter( * Get the layout * @return layout */ - def getLayout(): (Layout, Layout) = { + def getLayout(): (String, String) = { (currentBatch.dataLayout, currentBatch.labelLayout) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index 5de42290154a..f316709a404c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -20,7 +20,6 @@ package org.apache.mxnet.io import java.util.NoSuchElementException import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -142,7 +141,7 @@ class ResizeIter( * Get the layout * @return layout */ - def getLayout(): (Layout, Layout) = { + def getLayout(): (String, String) = { (currentBatch.dataLayout, currentBatch.labelLayout) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala index bb497e47d139..1494dc84035c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala @@ -330,7 +330,7 @@ class DataParallelExecutorGroup private[module]( */ private def decideSlices(dataShapes: Seq[DataDesc]): Seq[Int] = { require(dataShapes.size > 0) - val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layoutStr))) + val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layout))) for ((dataDesc, axis) <- dataShapes.zip(majorAxis)) { if (axis != -1) { diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 919c94b4b819..478f834d210b 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -252,7 +252,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test pad val dataIter0 = new NDArrayIter(data, label, 128, false, "pad", - dataLayout = Layout.TNC, labelLayout = Layout.NT) + dataLayout = "NTC", labelLayout = "NT") var batchCount = 0 val nBatch0 = 8 while(dataIter0.hasNext) { @@ -271,7 +271,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { .addData("data0", data(0)).addData("data1", data(1)) .addLabel("label", label(0)) .setBatchSize(128) - .setLayout(Layout.TNC, Layout.NT) + .setLayout("NTC", "NT") .setLastBatchHandle("discard").build() val nBatch1 = 7 batchCount = 0 @@ -288,7 +288,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test empty label (for prediction) val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard", - dataLayout = Layout.TNC) + dataLayout = "NTC") batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 10c547c5a9b2..a73195f77207 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -33,7 +33,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, Layout.TNC))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, "TNC"))) mod.initParams() mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape, dtype = dType)), @@ -57,9 +57,9 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { .setContext(Context.cpu(0), Context.cpu(1)) .build() mod.bind(dataShapes = IndexedSeq( - DataDesc("b", Shape(5, 5), layout = Layout.NT), - DataDesc("c", Shape(5, 5), layout = Layout.NT), - DataDesc("a", Shape(5, 5), layout = Layout.NT)), + DataDesc("b", Shape(5, 5), layout = "NT"), + DataDesc("c", Shape(5, 5), layout = "NT"), + DataDesc("a", Shape(5, 5), layout = "NT")), inputsNeedGrad = true ) mod.initParams() @@ -87,7 +87,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val dShape = Shape(3, 8, 7) val mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = Layout.TNC))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "TNC"))) mod.initParams() mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape)), @@ -110,14 +110,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // single device var mod = new Module(sym, IndexedSeq("data"), null) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) mod.update() mod.saveCheckpoint("test", 0, saveOptStates = true) var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true) - mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) + mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) assert(mod.getSymbol.toJson == mod2.getSymbol.toJson) mapEqu(mod.getParams._1, mod2.getParams._1) @@ -125,14 +125,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // multi device mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT" ))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) mod.update() mod.saveCheckpoint("test", 0, saveOptStates = true) mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true) - mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT))) + mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT"))) mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f)) assert(mod.getSymbol.toJson == mod2.getSymbol.toJson) mapEqu(mod.getParams._1, mod2.getParams._1) @@ -145,7 +145,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { var dShape = Shape(7, 20) val mod = new Module(sym, IndexedSeq("data"), null, contexts = Array(Context.cpu(0), Context.cpu(1))) - mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = Layout.NT))) + mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "NT"))) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 1f)) @@ -159,7 +159,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // reshape module dShape = Shape(14, 20) - mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = Layout.NT))) + mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT"))) mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape)), label = null, index = null, pad = 0)) @@ -170,7 +170,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // return to original binded shape dShape = Shape(7, 20) - mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = Layout.NT))) + mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT"))) mod.forward(new DataBatch( data = IndexedSeq(NDArray.ones(dShape)), label = null, index = null, pad = 0)) @@ -185,7 +185,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", - dataLayout = Layout.NCHW, labelLayout = Layout.NCHW) + dataLayout = "NCHW", labelLayout = "NCHW") // symbols var x = Symbol.Variable("data") @@ -236,7 +236,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", - dataLayout = Layout.NCHW, labelLayout = Layout.NCHW) + dataLayout = "NCHW", labelLayout = "NCHW") // symbols var x = Symbol.Variable("data") @@ -311,8 +311,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val mod = new Module(sym, IndexedSeq("data1", "data2")) mod.bind(dataShapes = IndexedSeq( - DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = Layout.NCHW)), - labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = Layout.N))) + DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = "NCHW")), + labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = "N"))) ) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f)) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala index 3bae4a703e4c..3bbd780d39b9 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala @@ -17,9 +17,10 @@ package org.apache.mxnetexamples.infer.imageclassifier -import org.apache.mxnet._ +import org.apache.mxnet.Shape import org.kohsuke.args4j.{CmdLineParser, Option} import org.slf4j.LoggerFactory +import org.apache.mxnet.{DType, DataDesc, Context} import org.apache.mxnet.infer.ImageClassifier import scala.collection.JavaConverters._ @@ -45,7 +46,7 @@ object ImageClassifierExample { val dType = DType.Float32 val inputShape = Shape(1, 3, 224, 224) - val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) + val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) // Create object of ImageClassifier class val imgClassifier: ImageClassifier = new @@ -66,7 +67,7 @@ object ImageClassifierExample { val dType = DType.Float32 val inputShape = Shape(1, 3, 224, 224) - val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) + val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) // Create object of ImageClassifier class val imgClassifier: ImageClassifier = new diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala index 2b8e49b8042f..c9707cb3ff6f 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala @@ -19,7 +19,7 @@ package org.apache.mxnetexamples.infer.objectdetector import java.io.File -import org.apache.mxnet._ +import org.apache.mxnet.{Context, DType, DataDesc, Shape} import org.apache.mxnet.infer._ import org.kohsuke.args4j.{CmdLineParser, Option} import org.slf4j.LoggerFactory @@ -58,7 +58,7 @@ object SSDClassifierExample { val inputShape = Shape(1, 3, 512, 512) // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...]) val outputShape = Shape(1, 6132, 6) - val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) + val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) val img = ImageClassifier.loadImageFromFile(inputImagePath) val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context) val output = objDetector.imageObjectDetect(img, Some(3)) @@ -73,7 +73,7 @@ object SSDClassifierExample { val inputShape = Shape(1, 3, 512, 512) // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...]) val outputShape = Shape(1, 6132, 6) - val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW)) + val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context) // Loading batch of images from the directory path val batchFiles = generateBatches(inputImageDir, 20) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 6410d7781295..0e865ab3c0ec 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -25,12 +25,12 @@ import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ import org.apache.commons.io.FileUtils + import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, NDArray, Shape, Symbol, Xavier} import org.apache.mxnet.DType.DType import org.apache.mxnet.optimizer.RMSProp import org.apache.mxnet.Executor import org.apache.mxnetexamples.Util -import org.apache.mxnet.Layout.Layout import scala.collection.immutable.ListMap import scala.sys.process.Process @@ -129,7 +129,7 @@ object ExampleMultiTask { override def getDType(): (DType, DType) = this.dataIter.getDType() - override def getLayout(): (Layout, Layout) = this.dataIter.getLayout() + override def getLayout(): (String, String) = this.dataIter.getLayout() // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this.dataIter.provideData diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index 2bf9654802e6..22688c1db006 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -19,7 +19,6 @@ package org.apache.mxnetexamples.rnn import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -99,8 +98,8 @@ object BucketIo { _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, readContent: ReadContent = defaultReadContent, - dataLayout: Layout = Layout.NT, - labelLayout: Layout = Layout.N, + dataLayout: String = "NT", + labelLayout: String = "N", dataDType : DType = DType.Float32, labelDType: DType = DType.Int32) extends DataIter { @@ -256,7 +255,7 @@ object BucketIo { override def getDType(): (DType, DType) = (dataDType, labelDType) - override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) + override def getLayout(): (String, String) = (dataLayout, labelLayout) // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = this._provideLabel diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala index db5923efe5c5..8d31d1f6b3d6 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala @@ -49,11 +49,10 @@ class ImageClassifier(modelPathPrefix: String, extends Classifier(modelPathPrefix, inputDescriptors, contexts, epoch) { - protected[infer] val inputLayout = inputDescriptors.head.layout.toString + protected[infer] val inputLayout = inputDescriptors.head.layout require(inputDescriptors.nonEmpty, "Please provide input descriptor") - require(inputDescriptors.head.layout.toString == "NCHW", - "Provided layout doesn't match NCHW format") + require(inputDescriptors.head.layout == "NCHW", "Provided layout doesn't match NCHW format") protected[infer] val inputShape = inputDescriptors.head.shape diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala index 75b55209b1d2..2a4f03056372 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala @@ -18,7 +18,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter -import org.apache.mxnet._ +import org.apache.mxnet.{Context, DataDesc, NDArray, Shape} import org.apache.mxnet.module.Module import scala.collection.mutable.ListBuffer @@ -76,15 +76,15 @@ class Predictor(modelPathPrefix: String, private val logger = LoggerFactory.getLogger(classOf[Predictor]) - require(inputDescriptors.head.layout.toString.size != 0, "layout size should not be zero") + require(inputDescriptors.head.layout.size != 0, "layout size should not be zero") - protected[infer] var batchIndex = inputDescriptors(0).layout.toString.indexOf('N') + protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N') protected[infer] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex) else 1 protected[infer] var iDescriptors = inputDescriptors - inputDescriptors.foreach((f: DataDesc) => require(f.layout.toString.indexOf('N') == batchIndex, + inputDescriptors.foreach((f: DataDesc) => require(f.layout.indexOf('N') == batchIndex, "batch size should be in the same index for all inputs")) if (batchIndex != -1) { @@ -94,7 +94,7 @@ class Predictor(modelPathPrefix: String, // Note: this is assuming that the input needs a batch logger.warn("InputDescriptor does not have batchSize, using 1 as the default batchSize") iDescriptors = inputDescriptors.map((f: DataDesc) => new DataDesc(f.name, - Shape(1 +: f.shape.toVector), f.dtype, Layout.getLayout('N' +: f.layout.toString))) + Shape(1 +: f.shape.toVector), f.dtype, 'N' +: f.layout)) batchIndex = 1 } diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala index 1eb3bd91e37e..948764ee8044 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.infer -import org.apache.mxnet._ +import org.apache.mxnet.{DType, DataDesc, Shape, NDArray, Context} + import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.BeforeAndAfterAll @@ -59,7 +60,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { test("ImageClassifierSuite-testConvertBufferedImageToNDArray") { val dType = DType.Float32 val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 2, 2), - dType, Layout.NCHW)) + dType, "NCHW")) val image1 = new BufferedImage(100, 200, BufferedImage.TYPE_BYTE_GRAY) val image2 = ImageClassifier.reshapeImage(image1, 2, 2) @@ -72,7 +73,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { test("ImageClassifierSuite-testWithInputImage") { val dType = DType.Float32 val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), - dType, Layout.NCHW)) + dType, "NCHW")) val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB) @@ -110,7 +111,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { test("ImageClassifierSuite-testWithInputBatchImage") { val dType = DType.Float32 val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), - dType, Layout.NCHW)) + dType, "NCHW")) val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB) val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage) diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala index cdd5146d0a03..53fd7f310689 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala @@ -19,7 +19,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter import org.apache.mxnet.module.{BaseModule, Module} -import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape} +import org.apache.mxnet.{DataDesc, NDArray, Shape} import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -45,7 +45,7 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { val mockPredictor = new MyPredictor("xyz", inputDescriptor) assert(mockPredictor.getBatchSize == 1) - assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.toString.indexOf('N')) + assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N')) val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)), new DataDesc("data", Shape(2, 3, 2, 2))) @@ -55,7 +55,7 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { } // batchsize is defaulted to 1 - val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = Layout.CHW)) + val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW")) val p2 = new MyPredictor("xyz", inputDescriptor) assert(p2.getBatchSize == 1, "should use a default batch size of 1") diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index b84c91807336..fbb8874477f2 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -18,7 +18,6 @@ package org.apache.mxnet.spark.io import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.apache.spark.mllib.regression.LabeledPoint @@ -37,8 +36,8 @@ class LabeledPointIter private[mxnet]( private val labelName: String = "label", private val dataDType: DType = DType.Float32, private val labelDType: DType = DType.Int32, - private val dataLayout: Layout = Layout.NCHW, - private val labelLayout: Layout = Layout.N) extends DataIter { + private val dataLayout: String = "NCHW", + private val labelLayout: String = "N") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -148,7 +147,7 @@ class LabeledPointIter private[mxnet]( override def getDType(): (DType, DType) = (dataDType, labelDType) - override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) + override def getLayout(): (String, String) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index 0d5068544adc..62bec0a5a5d6 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -18,7 +18,6 @@ package org.apache.mxnet.spark.io import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import org.apache.mxnet.{DataBatch, NDArray} /** @@ -30,8 +29,8 @@ class LongLivingDataBatch( override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], override val pad: Int, - override val dataLayout: Layout, - override val labelLayout: Layout, + override val dataLayout: String, + override val labelLayout: String, override val dataDType: DType, override val labelDType: DType) extends DataBatch(data, label, index, pad, dataLayout = dataLayout, labelLayout = labelLayout, diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index 96eacdab9ab0..a43906d5e365 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -18,7 +18,6 @@ package org.apache.mxnet.spark.io import org.apache.mxnet.DType.DType -import org.apache.mxnet.Layout.Layout import org.apache.mxnet._ import org.apache.spark.mllib.linalg.Vector @@ -37,8 +36,8 @@ class PointIter private[mxnet]( private val labelName: String = "label", private val dataDType: DType = DType.Float32, private val labelDType: DType = DType.Int32, - private val dataLayout: Layout = Layout.NCHW, - private val labelLayout: Layout = Layout.N) extends DataIter { + private val dataLayout: String = "NCHW", + private val labelLayout: String = "N") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -147,7 +146,7 @@ class PointIter private[mxnet]( override def getDType(): (DType, DType) = (dataDType, labelDType) - override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout) + override def getLayout(): (String, String) = (dataLayout, labelLayout) override def batchSize: Int = _batchSize From a0609d0210d53fd0bebbbdc08abc3a725f291e76 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 1 Aug 2018 12:18:07 -0700 Subject: [PATCH 11/25] fix coding style issue --- .../core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 4330ce36c130..1831d30b8066 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -47,6 +47,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], dataDType: DType, labelDType: DType, dataLayout: String, labelLayout: String) extends DataIter { + // scalastyle:off /** * @param data Specify the data. Data names will be data_0, data_1, ..., etc. * @param label Same as data, but is not fed to the model during testing. @@ -69,6 +70,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], IO.initData(label, allowEmpty = true, labelName), dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout) } + // scalastyle:on private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) From 75b518c39945f85cfbb80be4954c5d2eecd461cb Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 3 Aug 2018 11:52:30 -0700 Subject: [PATCH 12/25] print full stacktraces --- scala-package/pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 3511f4acfffd..c221b4721d81 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -231,6 +231,7 @@ ${skipTests} ${project.build.directory}/surefire-reports . + F WDF TestSuite.txt From ed92e73022bf5397e842affd1bc438637d9ae8b8 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 6 Aug 2018 19:17:12 -0700 Subject: [PATCH 13/25] apply changes to new constructor --- .../org/apache/mxnet/io/MXDataIter.scala | 19 +++++--- .../org/apache/mxnet/io/NDArrayIter.scala | 43 +++++++++++-------- .../test/scala/org/apache/mxnet/IOSuite.scala | 8 +++- .../scala/org/apache/mxnet/ModuleSuite.scala | 8 ++-- .../apache/mxnetexamples/rnn/BucketIo.scala | 22 +++++++--- .../mxnet/spark/io/LabeledPointIter.scala | 1 - .../org/apache/mxnet/spark/io/PointIter.scala | 1 - 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 292c8c49b965..b24f3a4d5970 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -31,17 +31,22 @@ import scala.collection.mutable.ListBuffer * @param handle the handle to the underlying C++ Data Iterator */ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, - dataName: String = "data", - labelName: String = "label", - dtype: DType = DType.Float32, - dataLayout: String = "NCHW", - labelLayout: String = "N", - dataDType: DType = DType.Float32, - labelDType: DType = DType.Int32) + dataName: String, + labelName: String, + dataLayout: String, + labelLayout: String, + dataDType: DType, + labelDType: DType) extends DataIter with WarnIfNotDisposed { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) + def this(handle: DataIterHandle, + dataName: String = "data", + labelName: String = "label") { + this(handle, dataName, labelName, "NCHW", "N", DType.Float32, DType.Int32) + } + // use currentBatch to implement hasNext // (may be this is not the best way to do this work, // fix me if any better way found) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 1831d30b8066..04d5f2d26f87 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -48,29 +48,36 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], dataLayout: String, labelLayout: String) extends DataIter { // scalastyle:off - /** - * @param data Specify the data. Data names will be data_0, data_1, ..., etc. - * @param label Same as data, but is not fed to the model during testing. - * Label names will be label_0, label_1, ..., etc. - * @param dataBatchSize Batch Size - * @param shuffle Whether to shuffle the data - * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch - * - * This iterator will pad, discard or roll over the last batch if - * the size of data does not match batch_size. Roll over is intended - * for training and can cause problems if used for prediction. - */ - def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, - dataBatchSize: Int = 1, shuffle: Boolean = false, - lastBatchHandle: String = "pad", - dataName: String = "data", labelName: String = "label", - dataDType: DType = MX_REAL_TYPE, labelDType: DType = DType.Int32, - dataLayout: String = "NCHW", labelLayout: String = "N") { + def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], + dataBatchSize: Int, shuffle: Boolean, + lastBatchHandle: String, + dataName: String, labelName: String, + dataDType: DType, labelDType: DType, + dataLayout: String, labelLayout: String) { this(IO.initData(data, allowEmpty = false, dataName), IO.initData(label, allowEmpty = true, labelName), dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout) } // scalastyle:on + /** + * @param data Specify the data. Data names will be data_0, data_1, ..., etc. + * @param label Same as data, but is not fed to the model during testing. + * Label names will be label_0, label_1, ..., etc. + * @param dataBatchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ + def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, + dataBatchSize: Int = 1, shuffle: Boolean = false, + lastBatchHandle: String = "pad", + dataName: String = "data", labelName: String = "label") { + this(data, label, dataBatchSize, shuffle, lastBatchHandle, dataName, labelName, + MX_REAL_TYPE, DType.Int32, "NCHW", "N") + } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 478f834d210b..10d723c964f9 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -252,6 +252,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test pad val dataIter0 = new NDArrayIter(data, label, 128, false, "pad", + dataName = "data", labelName = "label", dataDType = DType.Float32, labelDType = DType.Int32, dataLayout = "NTC", labelLayout = "NT") var batchCount = 0 val nBatch0 = 8 @@ -287,8 +288,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) // test empty label (for prediction) - val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard", - dataLayout = "NTC") + val dataIter2 = new NDArrayIter(data = data, label = IndexedSeq.empty, + dataBatchSize = 128, shuffle = false, lastBatchHandle = "discard", + dataName = "data", labelName = "label", + dataLayout = "NTC", labelLayout = "N", + dataDType = DType.Float32, labelDType = DType.Int32) batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index a73195f77207..f8a49ece70b9 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -184,8 +184,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", - dataLayout = "NCHW", labelLayout = "NCHW") + IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label", + DType.Float32, DType.Int32, "NCHW", "NCHW") // symbols var x = Symbol.Variable("data") @@ -235,8 +235,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label", - dataLayout = "NCHW", labelLayout = "NCHW") + IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label", + DType.Float32, DType.Int32, "NCHW", "NCHW") // symbols var x = Symbol.Variable("data") diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index 22688c1db006..44c03cee60d6 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -96,12 +96,22 @@ object BucketIo { class BucketSentenceIter( path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int], _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], - seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, - readContent: ReadContent = defaultReadContent, - dataLayout: String = "NT", - labelLayout: String = "N", - dataDType : DType = DType.Float32, - labelDType: DType = DType.Int32) extends DataIter { + seperateChar: String, text2Id: Text2Id, + readContent: ReadContent, + dataLayout: String, + labelLayout: String, + dataDType : DType, + labelDType: DType) extends DataIter { + + // scalastyle:off + def this(path: String, vocab: Map[String, Int], buckets: IndexedSeq[Int], + _batchSize: Int, initStates: IndexedSeq[(String, (Int, Int))], + seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, + readContent: ReadContent = defaultReadContent) { + this(path, vocab, buckets, _batchSize, initStates, seperateChar, text2Id, + readContent, "NT", "N", DType.Float32, DType.Int32) + } + // scalastyle:on private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter]) diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index fbb8874477f2..7b62f7d201b6 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer /** * A helper converter for LabeledPoint - * @author Yizhi Liu */ class LabeledPointIter private[mxnet]( private val points: Iterator[LabeledPoint], diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index a43906d5e365..e1f0c1ad77cc 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer /** * A temporary helper implementation for predicting Vectors - * @author Yizhi Liu */ class PointIter private[mxnet]( private val points: Iterator[Vector], From 83b582628137d13e04a073205c415547b1ed9cc6 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 6 Aug 2018 21:28:32 -0700 Subject: [PATCH 14/25] add databatch bcc --- .../src/main/scala/org/apache/mxnet/IO.scala | 33 ++++++++++++++----- .../org/apache/mxnet/io/MXDataIter.scala | 1 + .../org/apache/mxnet/io/NDArrayIter.scala | 1 + .../org/apache/mxnet/io/PrefetchingIter.scala | 1 + .../multitask/ExampleMultiTask.scala | 3 +- .../mxnet/spark/io/LongLivingDataBatch.scala | 1 + 6 files changed, 30 insertions(+), 10 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 56cd59a9c24a..a0da1116c75e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -142,15 +142,29 @@ class DataBatch(val data: IndexedSeq[NDArray], val pad: Int, // the key for the bucket that should be used for this batch, // for bucketing io only - val bucketKey: AnyRef = null, + val bucketKey: AnyRef, // use ListMap to indicate the order of data/label loading // (must match the order of input data/label) - private val providedData: ListMap[String, Shape] = null, - private val providedLabel: ListMap[String, Shape] = null, - val dataDType: DType = Base.MX_REAL_TYPE, - val labelDType: DType = DType.Int32, - val dataLayout: String = "NCHW", - val labelLayout: String = "N") { + private val providedData: ListMap[String, Shape], + private val providedLabel: ListMap[String, Shape], + val dataDType: DType, + val labelDType: DType, + val dataLayout: String, + val labelLayout: String) { + def this(data: IndexedSeq[NDArray], + label: IndexedSeq[NDArray], + index: IndexedSeq[Long], + pad: Int, + // the key for the bucket that should be used for this batch, + // for bucketing io only + bucketKey: AnyRef = null, + // use ListMap to indicate the order of data/label loading + // (must match the order of input data/label) + providedData: ListMap[String, Shape] = null, + providedLabel: ListMap[String, Shape] = null) { + this(data, label, index, pad, bucketKey, providedData, providedLabel, + MX_REAL_TYPE, DType.Int32, "NCHW", "N") + } /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -320,8 +334,9 @@ abstract class DataIter extends Iterator[DataBatch] { @throws(classOf[NoSuchElementException]) def next(): DataBatch = { new DataBatch(getData(), getLabel(), getIndex(), getPad(), - dataDType = getDType()._1, labelDType = getDType()._2, - dataLayout = getLayout()._1, labelLayout = getLayout()._2) + null, null, null, + getDType()._1, getDType()._2, + getLayout()._1, getLayout()._2) } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index b24f3a4d5970..65828c2ed960 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -127,6 +127,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), index = getIndex(), pad = getPad(), + null, null, null, dataDType = getDType()._1, labelDType = getDType()._2, dataLayout = getLayout()._1, labelLayout = getLayout()._2) } else { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 04d5f2d26f87..000c465b5711 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -171,6 +171,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], if (hasNext) { cursor += dataBatchSize new DataBatch(getData(), getLabel(), getIndex(), getPad(), + null, null, null, dataDType = getDType()._1, labelDType = getDType()._2, dataLayout = getLayout()._1, labelLayout = getLayout()._2) } else { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index 2658476402ab..cfffd38a5672 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -224,6 +224,7 @@ class PrefetchingIter( labels.toIndexedSeq.flatten, nextBatch(0).index, nextBatch(0).pad, + null, null, null, dataLayout = nextBatch(0).dataLayout, labelLayout = nextBatch(0).labelLayout, dataDType = nextBatch(0).dataDType, diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 0e865ab3c0ec..affb69fd8b58 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -67,7 +67,8 @@ object ExampleMultiTask { new DataBatch(batch.data, IndexedSeq(label, label), batch.index, - batch.pad, dataDType = batch.dataDType, labelDType = batch.labelDType, + batch.pad, null, null, null, + dataDType = batch.dataDType, labelDType = batch.labelDType, dataLayout = batch.dataLayout, labelLayout = batch.labelLayout) } else { throw new NoSuchElementException diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index 62bec0a5a5d6..8df185d9e57f 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -33,6 +33,7 @@ class LongLivingDataBatch( override val labelLayout: String, override val dataDType: DType, override val labelDType: DType) extends DataBatch(data, label, index, pad, + null, null, null, dataLayout = dataLayout, labelLayout = labelLayout, dataDType = dataDType, labelDType = labelDType) { override def dispose(): Unit = {} From 0e170cfb50a703c6d9900b1a76288318fdaa4423 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 7 Aug 2018 13:15:11 -0700 Subject: [PATCH 15/25] introduce undefined field --- .../src/main/scala/org/apache/mxnet/IO.scala | 32 +++++++++++++------ .../main/scala/org/apache/mxnet/Layout.scala | 26 +++++++++++++++ .../org/apache/mxnet/io/MXDataIter.scala | 3 +- .../org/apache/mxnet/io/NDArrayIter.scala | 12 +++---- .../scala/org/apache/mxnet/ModuleSuite.scala | 2 +- .../apache/mxnetexamples/rnn/BucketIo.scala | 2 +- .../mxnet/infer/ObjectDetectorSuite.scala | 8 +++-- .../apache/mxnet/infer/PredictorSuite.scala | 16 ++++++---- .../mxnet/spark/io/LabeledPointIter.scala | 6 ++-- .../org/apache/mxnet/spark/io/PointIter.scala | 6 ++-- 10 files changed, 80 insertions(+), 33 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index a0da1116c75e..8089bc466ad5 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -105,8 +105,8 @@ object IO { checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) val dataName = params.getOrElse("data_name", "data") val labelName = params.getOrElse("label_name", "label") - val dataLayout = params.getOrElse("dataLayout", "NCHW") - val labelLayout = params.getOrElse("labelLayout", "N") + val dataLayout = params.getOrElse("dataLayout", Layout.UNDEFINED) + val labelLayout = params.getOrElse("labelLayout", Layout.UNDEFINED) val dataDType = params.getOrElse("dataDType", "Float32") val labelDType = params.getOrElse("labelDType", "Int32") new MXDataIter(out.value, dataName, labelName, @@ -163,7 +163,7 @@ class DataBatch(val data: IndexedSeq[NDArray], providedData: ListMap[String, Shape] = null, providedLabel: ListMap[String, Shape] = null) { this(data, label, index, pad, bucketKey, providedData, providedLabel, - MX_REAL_TYPE, DType.Int32, "NCHW", "N") + MX_REAL_TYPE, MX_REAL_TYPE, Layout.UNDEFINED, Layout.UNDEFINED) } /** * Dispose its data and labels @@ -194,10 +194,10 @@ object DataBatch { private var label: IndexedSeq[NDArray] = null private var index: IndexedSeq[Long] = null private var pad: Int = 0 - private var dataLayout: String = "NCHW" - private var labelLayout: String = "N" - private var dataDType: DType = Base.MX_REAL_TYPE - private var labelDType: DType = DType.Int32 + private var dataLayout: String = Layout.UNDEFINED + private var labelLayout: String = Layout.UNDEFINED + private var dataDType: DType = MX_REAL_TYPE + private var labelDType: DType = MX_REAL_TYPE private var bucketKey: AnyRef = null private var datatShapes: ListMap[String, Shape] = null private var labelShapes: ListMap[String, Shape] = null @@ -408,8 +408,9 @@ abstract class DataPack() extends Iterable[DataBatch] { // Named data desc description contains name, shape, type and other extended attributes. case class DataDesc(name: String, shape: Shape, - dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") { - require(shape.length == layout.length, ("number of dimensions in shape :%d with" + + dtype: DType = Base.MX_REAL_TYPE, layout: String = Layout.UNDEFINED) { + require(layout == Layout.UNDEFINED || shape.length == layout.length, + ("number of dimensions in shape :%d with" + " shape: %s should match the length of the layout: %d with layout: %s"). format(shape.length, shape.toString, layout.length, layout)) @@ -419,6 +420,8 @@ case class DataDesc(name: String, shape: Shape, } object DataDesc { + + private val logger = LoggerFactory.getLogger(classOf[DataDesc]) /** * Get the dimension that corresponds to the batch size. * @param layout layout string. For example, "NCHW". @@ -428,7 +431,16 @@ object DataDesc { * for each data-parallelism device. */ def getBatchAxis(layout: Option[String]): Int = { - layout.map(_.indexOf('N')).getOrElse(0) + if (layout.isEmpty|| layout.get == Layout.UNDEFINED) { + logger.info("Found Undefined Layout, will use default index 0") + 0 + } else { + if (layout.get.contains('N')) { + layout.get.indexOf("N") + } else { + throw new IllegalArgumentException("No N found in Batch Axis!") + } + } } implicit def ListMap2Descs(shapes: ListMap[String, Shape]): IndexedSeq[DataDesc] = { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala new file mode 100644 index 000000000000..822c579c705e --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +object Layout { + val UNDEFINED = "__undefined__" + val NCHW = "NCHW" + val NTC = "NTC" + val NT = "NT" + val N = "N" +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 65828c2ed960..9d3a9643d5d1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -44,7 +44,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, def this(handle: DataIterHandle, dataName: String = "data", labelName: String = "label") { - this(handle, dataName, labelName, "NCHW", "N", DType.Float32, DType.Int32) + this(handle, dataName, labelName, Layout.UNDEFINED, Layout.UNDEFINED, + DType.Float32, DType.Float32) } // use currentBatch to implement hasNext diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 000c465b5711..f8845d2d4a64 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -76,7 +76,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label") { this(data, label, dataBatchSize, shuffle, lastBatchHandle, dataName, labelName, - MX_REAL_TYPE, DType.Int32, "NCHW", "N") + MX_REAL_TYPE, MX_REAL_TYPE, Layout.UNDEFINED, Layout.UNDEFINED) } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) @@ -289,10 +289,10 @@ object NDArrayIter { private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" - private var dataLayout: String = "NCHW" - private var labelLayout: String = "N" + private var dataLayout: String = Layout.UNDEFINED + private var labelLayout: String = Layout.UNDEFINED private var dataDType: DType = Base.MX_REAL_TYPE - private var labelDType: DType = DType.Int32 + private var labelDType: DType = Base.MX_REAL_TYPE /** * Add one data input with its name. @@ -350,8 +350,8 @@ object NDArrayIter { /** * Set the layout. - * @param dataLayout The layout of the data, default is NCHW - * @param labelLayout The layout of the label, default is N + * @param dataLayout The layout of the data, default is UNDEFINED + * @param labelLayout The layout of the label, default is UNDEFINED * @return this */ def setLayout(dataLayout: String, labelLayout: String): Builder = { diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index f8a49ece70b9..f652de379597 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -24,7 +24,7 @@ import org.apache.mxnet.io._ class ModuleSuite extends FunSuite with BeforeAndAfterAll { test ("model dtype") { - val dType = DType.Float16 + val dType = DType.Float32 val dShape = Shape(3, 8, 7) var sym = Symbol.Variable("data") diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index 44c03cee60d6..46e2bd698381 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -109,7 +109,7 @@ object BucketIo { seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, readContent: ReadContent = defaultReadContent) { this(path, vocab, buckets, _batchSize, initStates, seperateChar, text2Id, - readContent, "NT", "N", DType.Float32, DType.Int32) + readContent, Layout.UNDEFINED, Layout.UNDEFINED, DType.Float32, DType.Float32) } // scalastyle:on diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala index 8160f0f6eb41..39139f8d3d2e 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala @@ -19,6 +19,8 @@ package org.apache.mxnet.infer // scalastyle:off import java.awt.image.BufferedImage + +import org.apache.mxnet.{DType, Layout} // scalastyle:on import org.apache.mxnet.Context import org.apache.mxnet.DataDesc @@ -69,7 +71,8 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll { } test("objectDetectWithInputImage") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), + DType.Float32, Layout.NCHW)) val inputImage = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB) val testObjectDetector: ObjectDetector = new MyObjectDetector(modelPath, inputDescriptor) @@ -109,7 +112,8 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll { } test("objectDetectWithBatchImages") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), + DType.Float32, Layout.NCHW)) val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB) val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage) diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala index 53fd7f310689..509ffb35db8d 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala @@ -19,7 +19,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter import org.apache.mxnet.module.{BaseModule, Module} -import org.apache.mxnet.{DataDesc, NDArray, Shape} +import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape} import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -40,15 +40,17 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { } test("PredictorSuite-testPredictorConstruction") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2), + layout = Layout.NCHW)) val mockPredictor = new MyPredictor("xyz", inputDescriptor) assert(mockPredictor.getBatchSize == 1) assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N')) - val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)), - new DataDesc("data", Shape(2, 3, 2, 2))) + val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2), + layout = Layout.NCHW), + new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW)) assertThrows[IllegalArgumentException] { val mockPredictor = new MyPredictor("xyz", inputDescriptor2) @@ -63,7 +65,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { test("PredictorSuite-testWithFlatArrays") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2), + layout = Layout.NCHW)) val inputData = Array.fill[Float](12)(1) // this will disposed at the end of the predict call on Predictor. @@ -89,7 +92,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { } test("PredictorSuite-testWithNDArray") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2), + layout = Layout.NCHW)) val inputData = NDArray.ones(Shape(1, 3, 2, 2)) // this will disposed at the end of the predict call on Predictor. diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index 7b62f7d201b6..7c37dbb85d80 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -34,9 +34,9 @@ class LabeledPointIter private[mxnet]( private val dataName: String = "data", private val labelName: String = "label", private val dataDType: DType = DType.Float32, - private val labelDType: DType = DType.Int32, - private val dataLayout: String = "NCHW", - private val labelLayout: String = "N") extends DataIter { + private val labelDType: DType = DType.Float32, + private val dataLayout: String = Layout.UNDEFINED, + private val labelLayout: String = Layout.UNDEFINED) extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index e1f0c1ad77cc..149c92e5b027 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -34,9 +34,9 @@ class PointIter private[mxnet]( private val dataName: String = "data", private val labelName: String = "label", private val dataDType: DType = DType.Float32, - private val labelDType: DType = DType.Int32, - private val dataLayout: String = "NCHW", - private val labelLayout: String = "N") extends DataIter { + private val labelDType: DType = DType.Float32, + private val dataLayout: String = Layout.UNDEFINED, + private val labelLayout: String = Layout.UNDEFINED) extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 From 67fed73007d5bc1a80fc8048391af74e77dac31a Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 7 Aug 2018 16:35:24 -0700 Subject: [PATCH 16/25] Fix crashes when change provideData to provideDataDesc It looks like if we want to force conversion from Float32 to Int32 will cause a crash on JVM. Need to be addressed. --- .../src/main/scala/org/apache/mxnet/IO.scala | 2 +- .../org/apache/mxnet/io/NDArrayIter.scala | 18 ++++++++---------- .../scala/org/apache/mxnet/ModuleSuite.scala | 8 ++++---- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 8089bc466ad5..c02753ba2315 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -108,7 +108,7 @@ object IO { val dataLayout = params.getOrElse("dataLayout", Layout.UNDEFINED) val labelLayout = params.getOrElse("labelLayout", Layout.UNDEFINED) val dataDType = params.getOrElse("dataDType", "Float32") - val labelDType = params.getOrElse("labelDType", "Int32") + val labelDType = params.getOrElse("labelDType", "Float32") new MXDataIter(out.value, dataName, labelName, dataLayout = dataLayout, labelLayout = labelLayout, dataDType = DType.getType(dataDType), diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index f8845d2d4a64..454396aeeaee 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -115,18 +115,16 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private var cursor = -dataBatchSize private val (_provideData: ListMap[String, Shape], - _provideLabel: ListMap[String, Shape]) = { + _provideLabel: ListMap[String, Shape], + _provideDataDesc: IndexedSeq[DataDesc], + _provideLabelDesc: IndexedSeq[DataDesc]) = { val pData = ListMap.empty[String, Shape] ++ initData.map(getShape) val pLabel = ListMap.empty[String, Shape] ++ initLabel.map(getShape) - (pData, pLabel) - } - - private val (_provideDataDesc: IndexedSeq[DataDesc], - _provideLabelDesc: IndexedSeq[DataDesc]) = { - val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dataDType, dataLayout)) - val pLabel = initLabel.map(ele => - new DataDesc(ele._1, getShape(ele)._2, labelDType, labelLayout)) - (pData, pLabel) + val pDData = IndexedSeq.empty[DataDesc] ++ pData.map(ele => + new DataDesc(ele._1, ele._2, dataDType, dataLayout)) + val pDLabel = IndexedSeq.empty[DataDesc] ++ pLabel.map(ele => + new DataDesc(ele._1, ele._2, labelDType, labelLayout)) + (pData, pLabel, pDData, pDLabel) } /** diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index f652de379597..603c2a2c73a0 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -185,7 +185,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label", - DType.Float32, DType.Int32, "NCHW", "NCHW") + DType.Float32, DType.Float32, "NCHW", "NCHW") // symbols var x = Symbol.Variable("data") @@ -197,8 +197,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // create module val mod = new Module(x, contexts = Array(Context.cpu())) - mod.bind(dataShapes = trainData.provideData, - Option(trainData.provideLabel)) + mod.bind(dataShapes = trainData.provideDataDesc, + Option(trainData.provideLabelDesc)) val argParamsCorrect = Map( "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)), "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)), @@ -236,7 +236,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label", - DType.Float32, DType.Int32, "NCHW", "NCHW") + DType.Float32, DType.Float32, "NCHW", "NCHW") // symbols var x = Symbol.Variable("data") From 46754f083eaaba2406ee65ae28c4ec550481f0e7 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 9 Aug 2018 16:40:12 -0700 Subject: [PATCH 17/25] change spacing and revert test --- .../main/scala/org/apache/mxnet/DType.scala | 2 ++ .../src/main/scala/org/apache/mxnet/IO.scala | 12 +++++----- .../org/apache/mxnet/io/MXDataIter.scala | 4 ++-- .../test/scala/org/apache/mxnet/IOSuite.scala | 22 +++++-------------- .../scala/org/apache/mxnet/ModuleSuite.scala | 6 ++--- .../apache/mxnetexamples/gan/GanMnist.scala | 4 +--- .../imclassification/TrainMnist.scala | 8 ++----- .../apache/mxnetexamples/multitask/Data.scala | 8 ++----- 8 files changed, 22 insertions(+), 44 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala index b015bd2169b7..f3a8e8e9a4a5 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala @@ -42,6 +42,8 @@ object DType extends Enumeration { case "Float16" => DType.Float16 case "Float32" => DType.Float32 case "Float64" => DType.Float64 + case _ => throw new IllegalArgumentException( + s"DType: $dtypeStr not found! please set it in DType.scala") } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index c02753ba2315..5c8e3b735cdc 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -199,7 +199,7 @@ object DataBatch { private var dataDType: DType = MX_REAL_TYPE private var labelDType: DType = MX_REAL_TYPE private var bucketKey: AnyRef = null - private var datatShapes: ListMap[String, Shape] = null + private var dataShapes: ListMap[String, Shape] = null private var labelShapes: ListMap[String, Shape] = null /** @@ -285,10 +285,10 @@ object DataBatch { * @return this. */ def provideDataShape(name: String, shape: Shape): Builder = { - if (datatShapes == null) { - datatShapes = ListMap((name, shape)) + if (dataShapes == null) { + dataShapes = ListMap((name, shape)) } else { - datatShapes = datatShapes.updated(name, shape) + dataShapes = dataShapes.updated(name, shape) } this } @@ -310,7 +310,7 @@ object DataBatch { def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, + new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes, dataDType, labelDType, dataLayout, labelLayout) } } @@ -438,7 +438,7 @@ object DataDesc { if (layout.get.contains('N')) { layout.get.indexOf("N") } else { - throw new IllegalArgumentException("No N found in Batch Axis!") + throw new IllegalArgumentException("No N found in Batch Axis!") } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 9d3a9643d5d1..0eaae62401fd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -55,8 +55,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, private val (_provideDataDesc: IndexedSeq[DataDesc], _provideLabelDesc: IndexedSeq[DataDesc], - _provideData: ListMap[String, Shape], - _provideLabel: ListMap[String, Shape], + _provideData: ListMap[String, Shape], + _provideLabel: ListMap[String, Shape], _batchSize: Int) = { if (hasNext) { iterNext() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 10d723c964f9..4cd237fc8c80 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -101,9 +101,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "data_shape" -> "(3,28,28)", "batch_size" -> "100", "preprocess_threads" -> "4", - "prefetch_buffer" -> "1", - "dataLayout" -> "NCHW", - "labelLayout" -> "N" + "prefetch_buffer" -> "1" ) val imgRecIter = IO.ImageRecordIter(params) val nBatch = 500 @@ -149,9 +147,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10", - "dataLayout" -> "NT", - "labelLayout" -> "N" + "seed" -> "10" ) val mnistIter = IO.MNISTIter(params) @@ -188,9 +184,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10", - "dataLayout" -> "NT", - "labelLayout" -> "N" + "seed" -> "10" ) val mnistPack1 = IO.MNISTPack(params) @@ -252,8 +246,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test pad val dataIter0 = new NDArrayIter(data, label, 128, false, "pad", - dataName = "data", labelName = "label", dataDType = DType.Float32, labelDType = DType.Int32, - dataLayout = "NTC", labelLayout = "NT") + dataName = "data", labelName = "label") var batchCount = 0 val nBatch0 = 8 while(dataIter0.hasNext) { @@ -272,7 +265,6 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { .addData("data0", data(0)).addData("data1", data(1)) .addLabel("label", label(0)) .setBatchSize(128) - .setLayout("NTC", "NT") .setLastBatchHandle("discard").build() val nBatch1 = 7 batchCount = 0 @@ -288,11 +280,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) // test empty label (for prediction) - val dataIter2 = new NDArrayIter(data = data, label = IndexedSeq.empty, - dataBatchSize = 128, shuffle = false, lastBatchHandle = "discard", - dataName = "data", labelName = "label", - dataLayout = "NTC", labelLayout = "N", - dataDType = DType.Float32, labelDType = DType.Int32) + val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, shuffle = false, lastBatchHandle = "discard") batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 603c2a2c73a0..88e314e2a72c 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -184,8 +184,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label", - DType.Float32, DType.Float32, "NCHW", "NCHW") + IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") // symbols var x = Symbol.Variable("data") @@ -235,8 +234,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2)) val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2)) val trainData = new NDArrayIter( - IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label", - DType.Float32, DType.Float32, "NCHW", "NCHW") + IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label") // symbols var x = Symbol.Variable("data") diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala index f145c189148e..70846eebfb8e 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala @@ -130,9 +130,7 @@ object GanMnist { "label" -> s"$dataPath/train-labels-idx1-ubyte", "input_shape" -> s"(1, 28, 28)", "batch_size" -> s"$batchSize", - "shuffle" -> "True", - "dataLayout" -> "NCHW", - "labelLayout" -> "N" + "shuffle" -> "True" ) val mnistIter = IO.MNISTIter(params) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala index 4fce7235ed7f..bd0ce45ffe5f 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala @@ -77,9 +77,7 @@ object TrainMnist { "shuffle" -> "True", "flat" -> flat, "num_parts" -> kv.numWorkers.toString, - "part_index" -> kv.`rank`.toString, - "dataLayout" -> "NT", - "labelLayout" -> "N")) + "part_index" -> kv.`rank`.toString)) val eval = IO.MNISTIter(Map( "image" -> (dataDir + "t10k-images-idx3-ubyte"), @@ -89,9 +87,7 @@ object TrainMnist { "batch_size" -> batchSize.toString, "flat" -> flat, "num_parts" -> kv.numWorkers.toString, - "part_index" -> kv.`rank`.toString, - "dataLayout" -> "NT", - "labelLayout" -> "N")) + "part_index" -> kv.`rank`.toString)) (train, eval) } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala index 2b0a20b40e76..068aa6314f89 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala @@ -32,9 +32,7 @@ object Data { "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", "shuffle" -> "True", - "flat" -> flat, - "dataLayout" -> "NT", - "labelLayout" -> "N" + "flat" -> flat ) val trainDataIter = IO.MNISTIter(trainParams) val testParams = Map( @@ -42,9 +40,7 @@ object Data { "label" -> s"$dataPath/t10k-labels-idx1-ubyte", "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", - "flat" -> flat, - "dataLayout" -> "NT", - "labelLayout" -> "N" + "flat" -> flat ) val testDataIter = IO.MNISTIter(testParams) (trainDataIter, testDataIter) From 8c7fac12de63e2040bec2c464eb76e5f507086b5 Mon Sep 17 00:00:00 2001 From: Qing Date: Sun, 12 Aug 2018 22:11:19 -0700 Subject: [PATCH 18/25] apply DataDesc on DataBatch --- .../src/main/scala/org/apache/mxnet/IO.scala | 103 +++++++----------- .../org/apache/mxnet/io/MXDataIter.scala | 23 +--- .../org/apache/mxnet/io/NDArrayIter.scala | 19 +--- .../org/apache/mxnet/io/PrefetchingIter.scala | 22 +--- .../org/apache/mxnet/io/ResizeIter.scala | 16 --- .../test/scala/org/apache/mxnet/IOSuite.scala | 3 +- .../multitask/ExampleMultiTask.scala | 8 +- .../apache/mxnetexamples/rnn/BucketIo.scala | 8 +- .../mxnet/spark/io/LabeledPointIter.scala | 7 +- .../mxnet/spark/io/LongLivingDataBatch.scala | 10 +- .../org/apache/mxnet/spark/io/PointIter.scala | 7 +- 11 files changed, 50 insertions(+), 176 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 5c8e3b735cdc..e3889808e3bb 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -143,14 +143,10 @@ class DataBatch(val data: IndexedSeq[NDArray], // the key for the bucket that should be used for this batch, // for bucketing io only val bucketKey: AnyRef, - // use ListMap to indicate the order of data/label loading + // use DataDesc to indicate the order of data/label loading // (must match the order of input data/label) - private val providedData: ListMap[String, Shape], - private val providedLabel: ListMap[String, Shape], - val dataDType: DType, - val labelDType: DType, - val dataLayout: String, - val labelLayout: String) { + private val providedDataDesc: IndexedSeq[DataDesc], + private val providedLabelDesc: IndexedSeq[DataDesc]) { def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], index: IndexedSeq[Long], @@ -162,8 +158,8 @@ class DataBatch(val data: IndexedSeq[NDArray], // (must match the order of input data/label) providedData: ListMap[String, Shape] = null, providedLabel: ListMap[String, Shape] = null) { - this(data, label, index, pad, bucketKey, providedData, providedLabel, - MX_REAL_TYPE, MX_REAL_TYPE, Layout.UNDEFINED, Layout.UNDEFINED) + this(data, label, index, pad, bucketKey, + DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel)) } /** * Dispose its data and labels @@ -179,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray], } // The name and shape of data - def provideData: ListMap[String, Shape] = providedData + def provideData: ListMap[String, Shape] = { + var temp = ListMap[String, Shape]() + if (providedDataDesc == null) null + else { + providedDataDesc.foreach(ele => temp = temp + (ele.name -> ele.shape)) + temp + } + } // The name and shape of label - def provideLabel: ListMap[String, Shape] = providedLabel + def provideLabel: ListMap[String, Shape] = { + var temp = ListMap[String, Shape]() + if (providedLabelDesc == null) null + else { + providedLabelDesc.foreach(ele => temp = temp + (ele.name -> ele.shape)) + temp + } + } + + def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc + + def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc + } object DataBatch { @@ -194,13 +209,9 @@ object DataBatch { private var label: IndexedSeq[NDArray] = null private var index: IndexedSeq[Long] = null private var pad: Int = 0 - private var dataLayout: String = Layout.UNDEFINED - private var labelLayout: String = Layout.UNDEFINED - private var dataDType: DType = MX_REAL_TYPE - private var labelDType: DType = MX_REAL_TYPE private var bucketKey: AnyRef = null - private var dataShapes: ListMap[String, Shape] = null - private var labelShapes: ListMap[String, Shape] = null + private var dataShapes: IndexedSeq[DataDesc] = null + private var labelShapes: IndexedSeq[DataDesc] = null /** * Set the input data. @@ -244,30 +255,6 @@ object DataBatch { this } - /** - * Set the dtype. - * @param dataDType The dtype of the data, default is Float32 - * @param labelDType The dtype of the label, default is Int32 - * @return this - */ - def setDType(dataDType: DType, labelDType: DType): Builder = { - this.dataDType = dataDType - this.labelDType = labelDType - this - } - - /** - * Set the layout. - * @param dataLayout The layout of the data, default is NCHW - * @param labelLayout The layout of the label, default is N - * @return this - */ - def setLayout(dataLayout: String, labelLayout: String): Builder = { - this.dataLayout = dataLayout - this.labelLayout = labelLayout - this - } - /** * Set the bucket key, used for bucketing module. * @param bucketKey the bucket key related to this batch. @@ -280,15 +267,14 @@ object DataBatch { /** * Provide the shape of a data. - * @param name data name. - * @param shape data shape. + * @param dataDesc DataDescriptor * @return this. */ - def provideDataShape(name: String, shape: Shape): Builder = { + def provideDataShape(dataDesc: DataDesc): Builder = { if (dataShapes == null) { - dataShapes = ListMap((name, shape)) + dataShapes = IndexedSeq(dataDesc) } else { - dataShapes = dataShapes.updated(name, shape) + dataShapes = dataShapes ++ IndexedSeq(dataDesc) } this } @@ -299,19 +285,18 @@ object DataBatch { * @param shape label shape. * @return this. */ - def provideLabelShape(name: String, shape: Shape): Builder = { + def provideLabelShape(dataDesc: DataDesc): Builder = { if (labelShapes == null) { - labelShapes = ListMap((name, shape)) + labelShapes = IndexedSeq(dataDesc) } else { - labelShapes = labelShapes.updated(name, shape) + labelShapes = labelShapes ++ IndexedSeq(dataDesc) } this } def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes, - dataDType, labelDType, dataLayout, labelLayout) + new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes) } } } @@ -334,9 +319,7 @@ abstract class DataIter extends Iterator[DataBatch] { @throws(classOf[NoSuchElementException]) def next(): DataBatch = { new DataBatch(getData(), getLabel(), getIndex(), getPad(), - null, null, null, - getDType()._1, getDType()._2, - getLayout()._1, getLayout()._2) + null, null, null) } /** @@ -358,18 +341,6 @@ abstract class DataIter extends Iterator[DataBatch] { */ def getPad(): Int - /** - * Get the DType - * @return data and label DType of the DataIter - */ - def getDType(): (DType, DType) - - /** - * Get the layout - * @return data and label layout of the DataIter - */ - def getLayout(): (String, String) - /** * Get the index of current batch * @return the index of current batch diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 0eaae62401fd..c18ab743de69 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -62,12 +62,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, iterNext() val data = currentBatch.data(0) val label = currentBatch.label(0) - val dataType = currentBatch.dataDType - val labelDType = currentBatch.labelDType - val dataLayout = currentBatch.dataLayout - val labelLayout = currentBatch.labelLayout // properties - val res = (IndexedSeq(new DataDesc(dataName, data.shape, dataDType, dataLayout)), + val res = ( + IndexedSeq(new DataDesc(dataName, data.shape, dataDType, dataLayout)), IndexedSeq(new DataDesc(labelName, label.shape, labelDType, labelLayout)), ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), @@ -128,9 +125,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), index = getIndex(), pad = getPad(), - null, null, null, - dataDType = getDType()._1, labelDType = getDType()._2, - dataLayout = getLayout()._1, labelLayout = getLayout()._2) + null, null, null) } else { currentBatch = null } @@ -179,18 +174,6 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, out.value } - /** - * Get the DType - * @return DType - */ - def getDType(): (DType, DType) = (dataDType, labelDType) - - /** - * Get the layout - * @return layout - */ - def getLayout(): (String, String) = (dataLayout, labelLayout) - // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 454396aeeaee..90c4f71f294b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -169,9 +169,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], if (hasNext) { cursor += dataBatchSize new DataBatch(getData(), getLabel(), getIndex(), getPad(), - null, null, null, - dataDType = getDType()._1, labelDType = getDType()._2, - dataLayout = getLayout()._1, labelLayout = getLayout()._2) + null, null, null) } else { throw new NoSuchElementException } @@ -246,21 +244,6 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], } } - /** - * Get the DType - * @return DType - */ - def getDType(): (DType, DType) = { - (dataDType, labelDType) - } - - /** - * Get the layout - * @return layout - */ - def getLayout(): (String, String) = { - (dataLayout, labelLayout) - } // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = _provideData diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index cfffd38a5672..f791a3101f36 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -177,22 +177,6 @@ class PrefetchingIter( */ override def getPad(): Int = this.currentBatch.pad - /** - * Get the DType - * @return DType - */ - def getDType(): (DType, DType) = { - (currentBatch.dataDType, currentBatch.labelDType) - } - - /** - * Get the layout - * @return layout - */ - def getLayout(): (String, String) = { - (currentBatch.dataLayout, currentBatch.labelLayout) - } - // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = this._provideLabel @@ -224,11 +208,7 @@ class PrefetchingIter( labels.toIndexedSeq.flatten, nextBatch(0).index, nextBatch(0).pad, - null, null, null, - dataLayout = nextBatch(0).dataLayout, - labelLayout = nextBatch(0).labelLayout, - dataDType = nextBatch(0).dataDType, - labelDType = nextBatch(0).labelDType) + null, null, null) for (e <- dataTaken) e.release() true } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index f316709a404c..c86ed4e8f7e2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -129,22 +129,6 @@ class ResizeIter( currentBatch.pad } - /** - * Get the DType - * @return DType - */ - def getDType(): (DType, DType) = { - (currentBatch.dataDType, currentBatch.labelDType) - } - - /** - * Get the layout - * @return layout - */ - def getLayout(): (String, String) = { - (currentBatch.dataLayout, currentBatch.labelLayout) - } - override def batchSize: Int = { dataIter.batchSize } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 4cd237fc8c80..9403dc3e191c 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -280,7 +280,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) // test empty label (for prediction) - val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, shuffle = false, lastBatchHandle = "discard") + val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, shuffle = false, + lastBatchHandle = "discard") batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index affb69fd8b58..3e797a7919ef 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -67,9 +67,7 @@ object ExampleMultiTask { new DataBatch(batch.data, IndexedSeq(label, label), batch.index, - batch.pad, null, null, null, - dataDType = batch.dataDType, labelDType = batch.labelDType, - dataLayout = batch.dataLayout, labelLayout = batch.labelLayout) + batch.pad, null, null, null) } else { throw new NoSuchElementException } @@ -128,10 +126,6 @@ object ExampleMultiTask { */ override def getPad(): Int = this.dataIter.getPad() - override def getDType(): (DType, DType) = this.dataIter.getDType() - - override def getLayout(): (String, String) = this.dataIter.getLayout() - // The name and shape of data provided by this iterator override def provideData: ListMap[String, Shape] = this.dataIter.provideData diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index 46e2bd698381..adfe74597479 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -223,9 +223,7 @@ object BucketIo { getIndex(), getPad(), this.buckets(bucketIdx).asInstanceOf[AnyRef], - batchProvideData, batchProvideLabel, - getDType()._1, getDType()._2, - getLayout()._1, getLayout()._2) + batchProvideData, batchProvideLabel) } /** @@ -263,10 +261,6 @@ object BucketIo { */ override def getPad(): Int = 0 - override def getDType(): (DType, DType) = (dataDType, labelDType) - - override def getLayout(): (String, String) = (dataLayout, labelLayout) - // The name and shape of label provided by this iterator override def provideLabel: ListMap[String, Shape] = this._provideLabel diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index 7c37dbb85d80..cf5efc6b9045 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -76,8 +76,7 @@ class LabeledPointIter private[mxnet]( } val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( - IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - dataLayout, labelLayout, dataDType, labelDType) + IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad) cache += dataBatch dataBatch } @@ -144,10 +143,6 @@ class LabeledPointIter private[mxnet]( */ override def getPad(): Int = 0 - override def getDType(): (DType, DType) = (dataDType, labelDType) - - override def getLayout(): (String, String) = (dataLayout, labelLayout) - override def batchSize: Int = _batchSize override def hasNext: Boolean = { diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index 8df185d9e57f..e3272a4066b5 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -28,14 +28,8 @@ class LongLivingDataBatch( override val data: IndexedSeq[NDArray], override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], - override val pad: Int, - override val dataLayout: String, - override val labelLayout: String, - override val dataDType: DType, - override val labelDType: DType) extends DataBatch(data, label, index, pad, - null, null, null, - dataLayout = dataLayout, labelLayout = labelLayout, - dataDType = dataDType, labelDType = labelDType) { + override val pad: Int) extends DataBatch(data, label, index, pad, + null, null, null) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index 149c92e5b027..7cc43666c532 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -75,8 +75,7 @@ class PointIter private[mxnet]( } val pad = batchSize - instNum val dataBatch = new LongLivingDataBatch( - IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad, - dataLayout, labelLayout, dataDType, labelDType) + IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad) cache += dataBatch dataBatch } @@ -143,10 +142,6 @@ class PointIter private[mxnet]( */ override def getPad(): Int = 0 - override def getDType(): (DType, DType) = (dataDType, labelDType) - - override def getLayout(): (String, String) = (dataLayout, labelLayout) - override def batchSize: Int = _batchSize override def hasNext: Boolean = { From 70aa7f47cb1df4a63fcb9f3a15d54f8c38300f07 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 13 Aug 2018 20:59:47 -0700 Subject: [PATCH 19/25] unit test for NDArrayIter and MXDataiter --- .../test/scala/org/apache/mxnet/IOSuite.scala | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 9403dc3e191c..2d55f6ece3e7 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -40,7 +40,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "silent" -> "0", "seed" -> "10", "dataLayout" -> "NT", - "labelLayout" -> "N" + "labelLayout" -> "N", + "dataDType" -> "Float32", + "labelDType" -> "Int32" ) val mnistPack = IO.MNISTPack(params) @@ -59,6 +61,12 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val provideLabel = mnistIter.provideLabel assert(provideData("data") === Shape(100, 784)) assert(provideLabel("label") === Shape(100)) + val provideDataDesc = mnistIter.provideDataDesc + val provideLabelDesc = mnistIter.provideLabelDesc + assert(provideDataDesc(0).dtype == DType.Float32) + assert(provideDataDesc(0).layout == Layout.NT) + assert(provideLabelDesc(0).dtype == DType.Int32) + assert(provideLabelDesc(0).layout == Layout.N) // test_loop mnistIter.reset() batchCount = 0 @@ -293,5 +301,16 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) assert(dataIter2.initLabel == IndexedSeq.empty) + + // test implementation with DataDesc + val dataIter3 = new NDArrayIter(data, label, 128, false, + "pad", "data", "label", DType.Float32, + DType.Int32, Layout.NTC, Layout.NT) + val dataDesc = dataIter3.provideDataDesc + val labelDesc = dataIter3.provideLabelDesc + assert(dataDesc(0).dtype == DType.Float32) + assert(dataDesc(0).layout == Layout.NTC) + assert(labelDesc(0).dtype == DType.Int32) + assert(labelDesc(0).layout == Layout.NT) } } From e8d1a400c20b5cb1ff66b2788081b47e3b31bfa9 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 13 Aug 2018 21:06:36 -0700 Subject: [PATCH 20/25] apply changes on CR --- .../core/src/main/scala/org/apache/mxnet/IO.scala | 2 +- .../core/src/main/scala/org/apache/mxnet/Layout.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index e3889808e3bb..21bb80148d26 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -409,7 +409,7 @@ object DataDesc { if (layout.get.contains('N')) { layout.get.indexOf("N") } else { - throw new IllegalArgumentException("No N found in Batch Axis!") + throw new IllegalArgumentException("no Batch Axis('N') found in Layout!") } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala index 822c579c705e..cb75dbc40803 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala @@ -17,6 +17,15 @@ package org.apache.mxnet +/** + * Layout definition of DataDesc + * N Batch size + * C channels + * H Height + * W Weight + * T sequence length + * __undefined__ default value of Layout + */ object Layout { val UNDEFINED = "__undefined__" val NCHW = "NCHW" From 44bb97eb7b9be1e83db8a18e6437d6cc06fc3e95 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 14 Aug 2018 15:16:40 -0700 Subject: [PATCH 21/25] change NDArrayIter and revert the rest --- .../src/main/scala/org/apache/mxnet/IO.scala | 28 +++-- .../org/apache/mxnet/io/MXDataIter.scala | 22 ++-- .../org/apache/mxnet/io/NDArrayIter.scala | 106 ++++++++---------- .../org/apache/mxnet/io/PrefetchingIter.scala | 8 +- .../org/apache/mxnet/io/ResizeIter.scala | 2 + .../test/scala/org/apache/mxnet/IOSuite.scala | 19 +--- .../multitask/ExampleMultiTask.scala | 2 + .../apache/mxnetexamples/rnn/BucketIo.scala | 44 +++----- .../mxnet/spark/io/LabeledPointIter.scala | 14 +-- .../org/apache/mxnet/spark/io/PointIter.scala | 14 +-- 10 files changed, 110 insertions(+), 149 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 21bb80148d26..77917c261c08 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -105,29 +105,26 @@ object IO { checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) val dataName = params.getOrElse("data_name", "data") val labelName = params.getOrElse("label_name", "label") - val dataLayout = params.getOrElse("dataLayout", Layout.UNDEFINED) - val labelLayout = params.getOrElse("labelLayout", Layout.UNDEFINED) - val dataDType = params.getOrElse("dataDType", "Float32") - val labelDType = params.getOrElse("labelDType", "Float32") - new MXDataIter(out.value, dataName, labelName, - dataLayout = dataLayout, labelLayout = labelLayout, - dataDType = DType.getType(dataDType), - labelDType = DType.getType(labelDType)) + new MXDataIter(out.value, dataName, labelName) } // Convert data into canonical form. - private[mxnet] def initData(data: IndexedSeq[NDArray], - allowEmpty: Boolean, - defaultName: String): IndexedSeq[(String, NDArray)] = { + private[mxnet] def initDataDesc(data: IndexedSeq[NDArray], + allowEmpty: Boolean, + defaultName: String, + defaultDType: DType, + defaultLayout: String): IndexedSeq[(DataDesc, NDArray)] = { require(data != null) require(data != IndexedSeq.empty || allowEmpty) if (data == IndexedSeq.empty) { IndexedSeq() } else if (data.length == 1) { - IndexedSeq((defaultName, data(0))) + IndexedSeq((new DataDesc(defaultName, data(0).shape, + defaultDType, defaultLayout), data(0))) } else { data.zipWithIndex.map(item => { - (defaultName + "_" + item._2, item._1) + (new DataDesc(defaultName + "_" + item._2, item._1.shape, + defaultDType, defaultLayout), item._1) }).toIndexedSeq } } @@ -379,7 +376,7 @@ abstract class DataPack() extends Iterable[DataBatch] { // Named data desc description contains name, shape, type and other extended attributes. case class DataDesc(name: String, shape: Shape, - dtype: DType = Base.MX_REAL_TYPE, layout: String = Layout.UNDEFINED) { + dtype: DType = DType.Float32, layout: String = Layout.UNDEFINED) { require(layout == Layout.UNDEFINED || shape.length == layout.length, ("number of dimensions in shape :%d with" + " shape: %s should match the length of the layout: %d with layout: %s"). @@ -403,7 +400,7 @@ object DataDesc { */ def getBatchAxis(layout: Option[String]): Int = { if (layout.isEmpty|| layout.get == Layout.UNDEFINED) { - logger.info("Found Undefined Layout, will use default index 0") + logger.warn("Found Undefined Layout, will use default index 0") 0 } else { if (layout.get.contains('N')) { @@ -414,6 +411,7 @@ object DataDesc { } } + @deprecated implicit def ListMap2Descs(shapes: ListMap[String, Shape]): IndexedSeq[DataDesc] = { if (shapes != null) { shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index c18ab743de69..f7f858deb82d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -31,23 +31,12 @@ import scala.collection.mutable.ListBuffer * @param handle the handle to the underlying C++ Data Iterator */ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, - dataName: String, - labelName: String, - dataLayout: String, - labelLayout: String, - dataDType: DType, - labelDType: DType) + dataName: String = "data", + labelName: String = "label") extends DataIter with WarnIfNotDisposed { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) - def this(handle: DataIterHandle, - dataName: String = "data", - labelName: String = "label") { - this(handle, dataName, labelName, Layout.UNDEFINED, Layout.UNDEFINED, - DType.Float32, DType.Float32) - } - // use currentBatch to implement hasNext // (may be this is not the best way to do this work, // fix me if any better way found) @@ -64,8 +53,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, val label = currentBatch.label(0) // properties val res = ( - IndexedSeq(new DataDesc(dataName, data.shape, dataDType, dataLayout)), - IndexedSeq(new DataDesc(labelName, label.shape, labelDType, labelLayout)), + // TODO: need to allow user to specify DType and Layout + IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)), + IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)), ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), data.shape(0)) @@ -175,9 +165,11 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = _provideData // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = _provideLabel // Provide type:DataDesc of the data diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 90c4f71f294b..ec32f94708ef 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -40,25 +40,11 @@ import scala.collection.immutable.ListMap * the size of data does not match batch_size. Roll over is intended * for training and can cause problems if used for prediction. */ -class NDArrayIter(data: IndexedSeq[(String, NDArray)], - label: IndexedSeq[(String, NDArray)], +class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], + label: IndexedSeq[(DataDesc, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, - lastBatchHandle: String, - dataDType: DType, labelDType: DType, - dataLayout: String, labelLayout: String) extends DataIter { - - // scalastyle:off - def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], - dataBatchSize: Int, shuffle: Boolean, - lastBatchHandle: String, - dataName: String, labelName: String, - dataDType: DType, labelDType: DType, - dataLayout: String, labelLayout: String) { - this(IO.initData(data, allowEmpty = false, dataName), - IO.initData(label, allowEmpty = true, labelName), - dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout) - } - // scalastyle:on + lastBatchHandle: String) extends DataIter { + /** * @param data Specify the data. Data names will be data_0, data_1, ..., etc. * @param label Same as data, but is not fed to the model during testing. @@ -75,13 +61,14 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label") { - this(data, label, dataBatchSize, shuffle, lastBatchHandle, dataName, labelName, - MX_REAL_TYPE, MX_REAL_TYPE, Layout.UNDEFINED, Layout.UNDEFINED) + this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, Layout.UNDEFINED), + IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED), + dataBatchSize, shuffle, lastBatchHandle) } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) - val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = { + val (initData: IndexedSeq[(DataDesc, NDArray)], initLabel: IndexedSeq[(DataDesc, NDArray)]) = { // data should not be null and size > 0 require(data != null && data.size > 0, "data should not be null and data.size should not be zero") @@ -120,10 +107,14 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], _provideLabelDesc: IndexedSeq[DataDesc]) = { val pData = ListMap.empty[String, Shape] ++ initData.map(getShape) val pLabel = ListMap.empty[String, Shape] ++ initLabel.map(getShape) - val pDData = IndexedSeq.empty[DataDesc] ++ pData.map(ele => - new DataDesc(ele._1, ele._2, dataDType, dataLayout)) - val pDLabel = IndexedSeq.empty[DataDesc] ++ pLabel.map(ele => - new DataDesc(ele._1, ele._2, labelDType, labelLayout)) + val pDData = IndexedSeq.empty[DataDesc] ++ initData.map(ele => { + val temp = getShape(ele) + new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout) + }) + val pDLabel = IndexedSeq.empty[DataDesc] ++ initLabel.map(ele => { + val temp = getShape(ele) + new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout) + }) (pData, pLabel, pDData, pDLabel) } @@ -131,10 +122,10 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], * get shape via dataBatchSize * @param dataItem */ - private def getShape(dataItem: (String, NDArray)): (String, Shape) = { + private def getShape(dataItem: (DataDesc, NDArray)): (String, Shape) = { val len = dataItem._2.shape.size val newShape = dataItem._2.shape.slice(1, len) - (dataItem._1, Shape(Array[Int](dataBatchSize)) ++ newShape) + (dataItem._1.name, Shape(Array[Int](dataBatchSize)) ++ newShape) } @@ -193,7 +184,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], } } - private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = { + private def _getData(data: IndexedSeq[(DataDesc, NDArray)]): IndexedSeq[NDArray] = { require(cursor < numData, "DataIter needs reset.") if (data == null) { null @@ -246,9 +237,11 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = _provideData // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = _provideLabel // Provide type:DataDesc of the data @@ -266,14 +259,10 @@ object NDArrayIter { * Builder class for NDArrayIter. */ class Builder() { - private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty - private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty + private var data: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty + private var label: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" - private var dataLayout: String = Layout.UNDEFINED - private var labelLayout: String = Layout.UNDEFINED - private var dataDType: DType = Base.MX_REAL_TYPE - private var labelDType: DType = Base.MX_REAL_TYPE /** * Add one data input with its name. @@ -282,7 +271,8 @@ object NDArrayIter { * @return The builder object itself. */ def addData(name: String, data: NDArray): Builder = { - this.data = this.data ++ IndexedSeq((name, data)) + this.data = this.data ++ IndexedSeq((new DataDesc(name, + data.shape, DType.Float32, Layout.UNDEFINED), data)) this } @@ -293,7 +283,24 @@ object NDArrayIter { * @return The builder object itself. */ def addLabel(name: String, label: NDArray): Builder = { - this.label = this.label ++ IndexedSeq((name, label)) + this.label = this.label ++ IndexedSeq((new DataDesc(name, + label.shape, DType.Float32, Layout.UNDEFINED), label)) + this + } + + /** + * Add one data input with its DataDesc + */ + def addDataDesc(dataDesc: DataDesc, data: NDArray): Builder = { + this.data = this.data ++ IndexedSeq((dataDesc, data)) + this + } + + /** + * Add one label input with its DataDesc + */ + def addLabelDesc(labelDesc: DataDesc, label: NDArray): Builder = { + this.data = this.data ++ IndexedSeq((labelDesc, label)) this } @@ -317,37 +324,12 @@ object NDArrayIter { this } - /** - * Set the dtype. - * @param dataDType The dtype of the data, default is Float32 - * @param labelDType The dtype of the label, default is Int32 - * @return this - */ - def setDType(dataDType: DType, labelDType: DType): Builder = { - this.dataDType = dataDType - this.labelDType = labelDType - this - } - - /** - * Set the layout. - * @param dataLayout The layout of the data, default is UNDEFINED - * @param labelLayout The layout of the label, default is UNDEFINED - * @return this - */ - def setLayout(dataLayout: String, labelLayout: String): Builder = { - this.dataLayout = dataLayout - this.labelLayout = labelLayout - this - } - /** * Build the NDArrayIter object. * @return the built object. */ def build(): NDArrayIter = { - new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, - dataDType, labelDType, dataLayout, labelLayout) + new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle) } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index f791a3101f36..e59e3706317d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -89,15 +89,15 @@ class PrefetchingIter( } private val _provideLabelDesc: IndexedSeq[DataDesc] = { - if (dataNames == null) { + if (labelNames == null) { iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => acc ++ elem } } else { - iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2)) + iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2)) .map(m => m._1.map(t => - new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout) + new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout) ) ) .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => @@ -178,9 +178,11 @@ class PrefetchingIter( override def getPad(): Int = this.currentBatch.pad // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = this._provideLabel // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = this._provideData // Provide type:DataDesc of the data diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index c86ed4e8f7e2..e840af9395f7 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -134,11 +134,13 @@ class ResizeIter( } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = { dataIter.provideData } // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { dataIter.provideLabel } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 2d55f6ece3e7..2ec6f668dbcc 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -38,11 +38,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "shuffle" -> "1", "flat" -> "1", "silent" -> "0", - "seed" -> "10", - "dataLayout" -> "NT", - "labelLayout" -> "N", - "dataDType" -> "Float32", - "labelDType" -> "Int32" + "seed" -> "10" ) val mnistPack = IO.MNISTPack(params) @@ -61,12 +57,6 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val provideLabel = mnistIter.provideLabel assert(provideData("data") === Shape(100, 784)) assert(provideLabel("label") === Shape(100)) - val provideDataDesc = mnistIter.provideDataDesc - val provideLabelDesc = mnistIter.provideLabelDesc - assert(provideDataDesc(0).dtype == DType.Float32) - assert(provideDataDesc(0).layout == Layout.NT) - assert(provideLabelDesc(0).dtype == DType.Int32) - assert(provideLabelDesc(0).layout == Layout.N) // test_loop mnistIter.reset() batchCount = 0 @@ -303,9 +293,10 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(dataIter2.initLabel == IndexedSeq.empty) // test implementation with DataDesc - val dataIter3 = new NDArrayIter(data, label, 128, false, - "pad", "data", "label", DType.Float32, - DType.Int32, Layout.NTC, Layout.NT) + val dataIter3 = new NDArrayIter( + IO.initDataDesc(data, false, "data", DType.Float32, Layout.NTC), + IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT), + 128, false, "pad") val dataDesc = dataIter3.provideDataDesc val labelDesc = dataIter3.provideLabelDesc assert(dataDesc(0).dtype == DType.Float32) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 3e797a7919ef..825e46596755 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -102,6 +102,7 @@ object ExampleMultiTask { override def getIndex(): IndexedSeq[Long] = this.dataIter.getIndex() // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { val provideLabel = this.dataIter.provideLabel.toArray // Different labels should be used here for actual application @@ -127,6 +128,7 @@ object ExampleMultiTask { override def getPad(): Int = this.dataIter.getPad() // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = this.dataIter.provideData override def provideDataDesc: IndexedSeq[DataDesc] = this.dataIter.provideDataDesc diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index adfe74597479..d4b17074d48c 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -28,9 +28,6 @@ import scala.io.Source import scala.util.Random import scala.collection.mutable -/** - * @author Depeng Liang - */ object BucketIo { type Text2Id = (String, Map[String, Int]) => Array[Int] @@ -94,25 +91,14 @@ object BucketIo { } class BucketSentenceIter( - path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int], - _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], - seperateChar: String, text2Id: Text2Id, - readContent: ReadContent, - dataLayout: String, - labelLayout: String, - dataDType : DType, - labelDType: DType) extends DataIter { - - // scalastyle:off - def this(path: String, vocab: Map[String, Int], buckets: IndexedSeq[Int], - _batchSize: Int, initStates: IndexedSeq[(String, (Int, Int))], - seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, - readContent: ReadContent = defaultReadContent) { - this(path, vocab, buckets, _batchSize, initStates, seperateChar, text2Id, - readContent, Layout.UNDEFINED, Layout.UNDEFINED, DType.Float32, DType.Float32) - } - // scalastyle:on - + path: String, + vocab: Map[String, Int], + var buckets: IndexedSeq[Int], + _batchSize: Int, + private val initStates: IndexedSeq[(String, (Int, Int))], + seperateChar: String = " ", + text2Id: Text2Id = defaultText2Id, + readContent: ReadContent = defaultReadContent) extends DataIter { private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter]) private val content = readContent(path) @@ -185,13 +171,17 @@ object BucketIo { private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, _defaultBucketKey)) private val _provideDataDesc = { + // TODO: need to allow user to specify DType and Layout val tmp = IndexedSeq(new DataDesc("data", - Shape(_batchSize, _defaultBucketKey), dataDType, dataLayout)) - tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dataDType, dataLayout)) + Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED)) + tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), + DType.Float32, Layout.UNDEFINED)) } - private val _provideLabelDesc = IndexedSeq(new DataDesc("softmax_label", - Shape(_batchSize, _defaultBucketKey), labelDType, labelLayout)) + private val _provideLabelDesc = IndexedSeq( + // TODO: need to allow user to specify DType and Layout + new DataDesc("softmax_label", + Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED)) private var iBucket = 0 @@ -262,9 +252,11 @@ object BucketIo { override def getPad(): Int = 0 // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = this._provideLabel // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = this._provideData // Provide type:DataDesc of the data diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index cf5efc6b9045..bf1b26e4b48d 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -32,11 +32,7 @@ class LabeledPointIter private[mxnet]( private val dimension: Shape, private val _batchSize: Int, private val dataName: String = "data", - private val labelName: String = "label", - private val dataDType: DType = DType.Float32, - private val labelDType: DType = DType.Float32, - private val dataLayout: String = Layout.UNDEFINED, - private val labelLayout: String = Layout.UNDEFINED) extends DataIter { + private val labelName: String = "label") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -119,21 +115,25 @@ class LabeledPointIter private[mxnet]( } // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { ListMap(labelName -> Shape(_batchSize)) } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = { ListMap(dataName -> dataShape) } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dataDType, dataLayout)) + // TODO: need to allow user to specify DType and Layout + IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, Layout.UNDEFINED)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), labelDType, labelLayout)) + // TODO: need to allow user to specify DType and Layout + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), DType.Float32, Layout.UNDEFINED)) } /** diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index 7cc43666c532..a955ee74e7e2 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -32,11 +32,7 @@ class PointIter private[mxnet]( private val dimension: Shape, private val _batchSize: Int, private val dataName: String = "data", - private val labelName: String = "label", - private val dataDType: DType = DType.Float32, - private val labelDType: DType = DType.Float32, - private val dataLayout: String = Layout.UNDEFINED, - private val labelLayout: String = Layout.UNDEFINED) extends DataIter { + private val labelName: String = "label") extends DataIter { private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch] private var index: Int = -1 @@ -118,21 +114,25 @@ class PointIter private[mxnet]( } // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { ListMap(labelName -> Shape(_batchSize)) } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = { ListMap(dataName -> dataShape) } override def provideDataDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(dataName, dataShape, dataDType, dataLayout)) + // TODO: Make DType, Layout configurable + IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, Layout.UNDEFINED)) } override def provideLabelDesc: IndexedSeq[DataDesc] = { - IndexedSeq(new DataDesc(labelName, Shape(_batchSize), labelDType, labelLayout)) + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), + DType.Float32, Layout.UNDEFINED)) } /** From 09db459308c4491c22c3a8180177a9698962a9ed Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 14 Aug 2018 15:21:51 -0700 Subject: [PATCH 22/25] revert change on examples --- .../scala/org/apache/mxnetexamples/customop/Data.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala index 230c56e38678..d61269c131ff 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala @@ -20,7 +20,6 @@ package org.apache.mxnetexamples.customop import org.apache.mxnet.{DataIter, IO, Shape} object Data { - // return train and val iterators for mnist def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = { val flat = if (inputShape.length == 3) "False" else "True" @@ -30,9 +29,7 @@ object Data { "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", "shuffle" -> "True", - "flat" -> flat, - "dataLayout" -> "NT", - "labelLayout" -> "N" + "flat" -> flat ) val trainDataIter = IO.MNISTIter(trainParams) val testParams = Map( @@ -40,9 +37,7 @@ object Data { "label" -> s"$dataPath/t10k-labels-idx1-ubyte", "input_shape" -> inputShape.toString(), "batch_size" -> s"$batchSize", - "flat" -> flat, - "dataLayout" -> "NT", - "labelLayout" -> "N" + "flat" -> flat ) val testDataIter = IO.MNISTIter(testParams) (trainDataIter, testDataIter) From 3f862e47a6fb533fbcc0336556a036070b58637c Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 15 Aug 2018 16:14:57 -0700 Subject: [PATCH 23/25] apply final changes --- .../src/main/scala/org/apache/mxnet/IO.scala | 34 ++++++++++++------- .../org/apache/mxnet/io/NDArrayIter.scala | 26 +++++++------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 77917c261c08..b2ab44b4d8dd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -207,8 +207,8 @@ object DataBatch { private var index: IndexedSeq[Long] = null private var pad: Int = 0 private var bucketKey: AnyRef = null - private var dataShapes: IndexedSeq[DataDesc] = null - private var labelShapes: IndexedSeq[DataDesc] = null + private var dataDesc: IndexedSeq[DataDesc] = null + private var labelDesc: IndexedSeq[DataDesc] = null /** * Set the input data. @@ -262,16 +262,26 @@ object DataBatch { this } + @deprecated + def provideDataShape(name: String, shape: Shape): Builder = { + provideDataDesc(new DataDesc(name, shape)) + } + + @deprecated + def provideLabelShape(name: String, shape: Shape): Builder = { + provideLabelDesc(new DataDesc(name, shape)) + } + /** * Provide the shape of a data. * @param dataDesc DataDescriptor * @return this. */ - def provideDataShape(dataDesc: DataDesc): Builder = { - if (dataShapes == null) { - dataShapes = IndexedSeq(dataDesc) + def provideDataDesc(dataDesc: DataDesc): Builder = { + if (this.dataDesc == null) { + this.dataDesc = IndexedSeq(dataDesc) } else { - dataShapes = dataShapes ++ IndexedSeq(dataDesc) + this.dataDesc = IndexedSeq(dataDesc) } this } @@ -282,18 +292,18 @@ object DataBatch { * @param shape label shape. * @return this. */ - def provideLabelShape(dataDesc: DataDesc): Builder = { - if (labelShapes == null) { - labelShapes = IndexedSeq(dataDesc) + def provideLabelDesc(dataDesc: DataDesc): Builder = { + if (this.labelDesc == null) { + this.labelDesc = IndexedSeq(dataDesc) } else { - labelShapes = labelShapes ++ IndexedSeq(dataDesc) + this.labelDesc = this.labelDesc ++ IndexedSeq(dataDesc) } this } def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes) + new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc) } } } @@ -400,7 +410,7 @@ object DataDesc { */ def getBatchAxis(layout: Option[String]): Int = { if (layout.isEmpty|| layout.get == Layout.UNDEFINED) { - logger.warn("Found Undefined Layout, will use default index 0") + logger.warn("Found Undefined Layout, will use default index 0 for batch axis") 0 } else { if (layout.get.contains('N')) { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index ec32f94708ef..e6be0ad02f83 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -46,17 +46,17 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], lastBatchHandle: String) extends DataIter { /** - * @param data Specify the data. Data names will be data_0, data_1, ..., etc. - * @param label Same as data, but is not fed to the model during testing. - * Label names will be label_0, label_1, ..., etc. - * @param dataBatchSize Batch Size - * @param shuffle Whether to shuffle the data - * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch - * - * This iterator will pad, discard or roll over the last batch if - * the size of data does not match batch_size. Roll over is intended - * for training and can cause problems if used for prediction. - */ + * @param data Specify the data. Data names will be data_0, data_1, ..., etc. + * @param label Same as data, but is not fed to the model during testing. + * Label names will be label_0, label_1, ..., etc. + * @param dataBatchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", @@ -291,7 +291,7 @@ object NDArrayIter { /** * Add one data input with its DataDesc */ - def addDataDesc(dataDesc: DataDesc, data: NDArray): Builder = { + def addDataWithDesc(dataDesc: DataDesc, data: NDArray): Builder = { this.data = this.data ++ IndexedSeq((dataDesc, data)) this } @@ -299,7 +299,7 @@ object NDArrayIter { /** * Add one label input with its DataDesc */ - def addLabelDesc(labelDesc: DataDesc, label: NDArray): Builder = { + def addLabelWithDesc(labelDesc: DataDesc, label: NDArray): Builder = { this.data = this.data ++ IndexedSeq((labelDesc, label)) this } From ad280592225d66da14ae0d4c9a0286c045662c5d Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 15 Aug 2018 17:01:25 -0700 Subject: [PATCH 24/25] remove the provideLabelShape --- .../src/main/scala/org/apache/mxnet/IO.scala | 29 ++++--------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index b2ab44b4d8dd..b2003892d748 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -262,42 +262,23 @@ object DataBatch { this } - @deprecated - def provideDataShape(name: String, shape: Shape): Builder = { - provideDataDesc(new DataDesc(name, shape)) - } - - @deprecated - def provideLabelShape(name: String, shape: Shape): Builder = { - provideLabelDesc(new DataDesc(name, shape)) - } - /** * Provide the shape of a data. * @param dataDesc DataDescriptor * @return this. */ - def provideDataDesc(dataDesc: DataDesc): Builder = { - if (this.dataDesc == null) { - this.dataDesc = IndexedSeq(dataDesc) - } else { - this.dataDesc = IndexedSeq(dataDesc) - } + def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = { + this.dataDesc = dataDesc this } /** * Provide the shape of a label. - * @param name label name. - * @param shape label shape. + * @param labelDesc LabelDescriptor * @return this. */ - def provideLabelDesc(dataDesc: DataDesc): Builder = { - if (this.labelDesc == null) { - this.labelDesc = IndexedSeq(dataDesc) - } else { - this.labelDesc = this.labelDesc ++ IndexedSeq(dataDesc) - } + def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = { + this.labelDesc = labelDesc this } From 93329a19cf8a04d7586898fa7ad1d329aa724d6b Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 16 Aug 2018 11:50:43 -0700 Subject: [PATCH 25/25] add TODO about the findings --- scala-package/core/src/main/scala/org/apache/mxnet/IO.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index b2003892d748..a1095cf04833 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -144,6 +144,9 @@ class DataBatch(val data: IndexedSeq[NDArray], // (must match the order of input data/label) private val providedDataDesc: IndexedSeq[DataDesc], private val providedLabelDesc: IndexedSeq[DataDesc]) { + // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)] + // However, since the data and label can be accessed publicly (no getter and setter) + // the change on this will break BC def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], index: IndexedSeq[Long],