Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
apply final changes
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Aug 15, 2018
1 parent aea880e commit ca3b664
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
34 changes: 22 additions & 12 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ object DataBatch {
private var index: IndexedSeq[Long] = null
private var pad: Int = 0
private var bucketKey: AnyRef = null
private var dataShapes: IndexedSeq[DataDesc] = null
private var labelShapes: IndexedSeq[DataDesc] = null
private var dataDesc: IndexedSeq[DataDesc] = null
private var labelDesc: IndexedSeq[DataDesc] = null

/**
* Set the input data.
Expand Down Expand Up @@ -262,16 +262,26 @@ object DataBatch {
this
}

@deprecated
def provideDataShape(name: String, shape: Shape): Builder = {
provideDataDesc(new DataDesc(name, shape))
}

@deprecated
def provideLabelShape(name: String, shape: Shape): Builder = {
provideLabelDesc(new DataDesc(name, shape))
}

/**
* Provide the shape of a data.
* @param dataDesc DataDescriptor
* @return this.
*/
def provideDataShape(dataDesc: DataDesc): Builder = {
if (dataShapes == null) {
dataShapes = IndexedSeq(dataDesc)
def provideDataDesc(dataDesc: DataDesc): Builder = {
if (this.dataDesc == null) {
this.dataDesc = IndexedSeq(dataDesc)
} else {
dataShapes = dataShapes ++ IndexedSeq(dataDesc)
this.dataDesc = IndexedSeq(dataDesc)
}
this
}
Expand All @@ -282,18 +292,18 @@ object DataBatch {
* @param shape label shape.
* @return this.
*/
def provideLabelShape(dataDesc: DataDesc): Builder = {
if (labelShapes == null) {
labelShapes = IndexedSeq(dataDesc)
def provideLabelDesc(dataDesc: DataDesc): Builder = {
if (this.labelDesc == null) {
this.labelDesc = IndexedSeq(dataDesc)
} else {
labelShapes = labelShapes ++ IndexedSeq(dataDesc)
this.labelDesc = this.labelDesc ++ IndexedSeq(dataDesc)
}
this
}

def build(): DataBatch = {
require(data != null, "data is required.")
new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes)
new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc)
}
}
}
Expand Down Expand Up @@ -400,7 +410,7 @@ object DataDesc {
*/
def getBatchAxis(layout: Option[String]): Int = {
if (layout.isEmpty|| layout.get == Layout.UNDEFINED) {
logger.warn("Found Undefined Layout, will use default index 0")
logger.warn("Found Undefined Layout, will use default index 0 for batch axis")
0
} else {
if (layout.get.contains('N')) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
lastBatchHandle: String) extends DataIter {

/**
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
* @param label Same as data, but is not fed to the model during testing.
* Label names will be label_0, label_1, ..., etc.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
* @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
*
* This iterator will pad, discard or roll over the last batch if
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
* @param label Same as data, but is not fed to the model during testing.
* Label names will be label_0, label_1, ..., etc.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
* @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
*
* This iterator will pad, discard or roll over the last batch if
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
Expand Down Expand Up @@ -291,15 +291,15 @@ object NDArrayIter {
/**
* Add one data input with its DataDesc
*/
def addDataDesc(dataDesc: DataDesc, data: NDArray): Builder = {
def addDataWithDesc(dataDesc: DataDesc, data: NDArray): Builder = {
this.data = this.data ++ IndexedSeq((dataDesc, data))
this
}

/**
* Add one label input with its DataDesc
*/
def addLabelDesc(labelDesc: DataDesc, label: NDArray): Builder = {
def addLabelWithDesc(labelDesc: DataDesc, label: NDArray): Builder = {
this.data = this.data ++ IndexedSeq((labelDesc, label))
this
}
Expand Down

0 comments on commit ca3b664

Please sign in to comment.