diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 77917c261c08..b2ab44b4d8dd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -207,8 +207,8 @@ object DataBatch { private var index: IndexedSeq[Long] = null private var pad: Int = 0 private var bucketKey: AnyRef = null - private var dataShapes: IndexedSeq[DataDesc] = null - private var labelShapes: IndexedSeq[DataDesc] = null + private var dataDesc: IndexedSeq[DataDesc] = null + private var labelDesc: IndexedSeq[DataDesc] = null /** * Set the input data. @@ -262,16 +262,26 @@ object DataBatch { this } + @deprecated + def provideDataShape(name: String, shape: Shape): Builder = { + provideDataDesc(new DataDesc(name, shape)) + } + + @deprecated + def provideLabelShape(name: String, shape: Shape): Builder = { + provideLabelDesc(new DataDesc(name, shape)) + } + /** * Provide the shape of a data. * @param dataDesc DataDescriptor * @return this. */ - def provideDataShape(dataDesc: DataDesc): Builder = { - if (dataShapes == null) { - dataShapes = IndexedSeq(dataDesc) + def provideDataDesc(dataDesc: DataDesc): Builder = { + if (this.dataDesc == null) { + this.dataDesc = IndexedSeq(dataDesc) } else { - dataShapes = dataShapes ++ IndexedSeq(dataDesc) + this.dataDesc = IndexedSeq(dataDesc) } this } @@ -282,18 +292,18 @@ object DataBatch { * @param shape label shape. * @return this. */ - def provideLabelShape(dataDesc: DataDesc): Builder = { - if (labelShapes == null) { - labelShapes = IndexedSeq(dataDesc) + def provideLabelDesc(dataDesc: DataDesc): Builder = { + if (this.labelDesc == null) { + this.labelDesc = IndexedSeq(dataDesc) } else { - labelShapes = labelShapes ++ IndexedSeq(dataDesc) + this.labelDesc = this.labelDesc ++ IndexedSeq(dataDesc) } this } def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes) + new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc) } } } @@ -400,7 +410,7 @@ object DataDesc { */ def getBatchAxis(layout: Option[String]): Int = { if (layout.isEmpty|| layout.get == Layout.UNDEFINED) { - logger.warn("Found Undefined Layout, will use default index 0") + logger.warn("Found Undefined Layout, will use default index 0 for batch axis") 0 } else { if (layout.get.contains('N')) { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index ec32f94708ef..e6be0ad02f83 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -46,17 +46,17 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], lastBatchHandle: String) extends DataIter { /** - * @param data Specify the data. Data names will be data_0, data_1, ..., etc. - * @param label Same as data, but is not fed to the model during testing. - * Label names will be label_0, label_1, ..., etc. - * @param dataBatchSize Batch Size - * @param shuffle Whether to shuffle the data - * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch - * - * This iterator will pad, discard or roll over the last batch if - * the size of data does not match batch_size. Roll over is intended - * for training and can cause problems if used for prediction. - */ + * @param data Specify the data. Data names will be data_0, data_1, ..., etc. + * @param label Same as data, but is not fed to the model during testing. + * Label names will be label_0, label_1, ..., etc. + * @param dataBatchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", @@ -291,7 +291,7 @@ object NDArrayIter { /** * Add one data input with its DataDesc */ - def addDataDesc(dataDesc: DataDesc, data: NDArray): Builder = { + def addDataWithDesc(dataDesc: DataDesc, data: NDArray): Builder = { this.data = this.data ++ IndexedSeq((dataDesc, data)) this } @@ -299,7 +299,7 @@ object NDArrayIter { /** * Add one label input with its DataDesc */ - def addLabelDesc(labelDesc: DataDesc, label: NDArray): Builder = { + def addLabelWithDesc(labelDesc: DataDesc, label: NDArray): Builder = { this.data = this.data ++ IndexedSeq((labelDesc, label)) this }