Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add changes with dataLayout and labelLayout
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jul 24, 2018
1 parent b32e5e9 commit a779124
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 67 deletions.
28 changes: 18 additions & 10 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ object IO {
checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out))
val dataName = params.getOrElse("data_name", "data")
val labelName = params.getOrElse("label_name", "label")
new MXDataIter(out.value, dataName, labelName)
val dataLayout = params.getOrElse("dataLayout", "NCHW")
val labelLayout = params.getOrElse("labelLayout", "N")
new MXDataIter(out.value, dataName, labelName,
dataLayout = dataLayout, labelLayout = labelLayout)
}

// Convert data into canonical form.
Expand Down Expand Up @@ -142,7 +145,8 @@ class DataBatch(val data: IndexedSeq[NDArray],
private val providedData: ListMap[String, Shape] = null,
private val providedLabel: ListMap[String, Shape] = null,
val dtype: DType = Base.MX_REAL_TYPE,
val layout: String = "NCHW") {
val dataLayout: String = "NCHW",
val labelLayout: String = "N") {
/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
Expand Down Expand Up @@ -172,7 +176,8 @@ object DataBatch {
private var label: IndexedSeq[NDArray] = null
private var index: IndexedSeq[Long] = null
private var pad: Int = 0
private var layout: String = "NCHW"
private var dataLayout: String = "NCHW"
private var labelLayout: String = "N"
private var dtype: DType = Base.MX_REAL_TYPE
private var bucketKey: AnyRef = null
private var datatShapes: ListMap[String, Shape] = null
Expand Down Expand Up @@ -232,11 +237,13 @@ object DataBatch {

/**
* Set the layout.
* @param layout The layout of the label, default is NCHW
* @param dataLayout The layout of the data, default is NCHW
* @param labelLayout The layout of the label, default is N
* @return this
*/
def setLayout(layout: String): Builder = {
this.layout = layout
def setLayout(dataLayout: String, labelLayout: String): Builder = {
this.dataLayout = dataLayout
this.labelLayout = labelLayout
this
}

Expand Down Expand Up @@ -282,7 +289,8 @@ object DataBatch {

def build(): DataBatch = {
require(data != null, "data is required.")
new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes, dtype, layout)
new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes,
dtype, dataLayout, labelLayout)
}
}
}
Expand All @@ -305,7 +313,7 @@ abstract class DataIter extends Iterator[DataBatch] {
@throws(classOf[NoSuchElementException])
def next(): DataBatch = {
new DataBatch(getData(), getLabel(), getIndex(), getPad(),
dtype = getDType(), layout = getLayout())
dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2)
}

/**
Expand Down Expand Up @@ -335,9 +343,9 @@ abstract class DataIter extends Iterator[DataBatch] {

/**
* Get the layout
* @return layout of the DataIter
* @return data and label layout of the DataIter
*/
def getLayout(): String
def getLayout(): (String, String)

/**
* Get the index of current batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ import scala.collection.mutable.ListBuffer
*/
private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
dataName: String = "data",
labelName: String = "label")
labelName: String = "label",
dtype: DType = DType.Float32,
dataLayout: String = "NCHW",
labelLayout: String = "N")
extends DataIter with WarnIfNotDisposed {

private val logger = LoggerFactory.getLogger(classOf[MXDataIter])
Expand Down Expand Up @@ -65,10 +68,11 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
val data = currentBatch.data(0)
val label = currentBatch.label(0)
val dType = currentBatch.dtype
val layout = currentBatch.layout
val dataLayout = currentBatch.dataLayout
val labelLayout = currentBatch.labelLayout
// properties
val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, layout)),
IndexedSeq(new DataDesc(labelName, label.shape, dType, layout)),
val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, dataLayout)),
IndexedSeq(new DataDesc(labelName, label.shape, dType, labelLayout)),
data.shape(0))
currentBatch.dispose()
reset()
Expand Down Expand Up @@ -126,7 +130,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
if (next.value > 0) {
currentBatch = new DataBatch(data = getData(), label = getLabel(),
index = getIndex(), pad = getPad(),
dtype = currentBatch.dtype, layout = currentBatch.layout)
dtype = getDType(), dataLayout = getLayout()._1,
labelLayout = getLayout()._2)
} else {
currentBatch = null
}
Expand Down Expand Up @@ -179,17 +184,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
* Get the DType
* @return DType
*/
def getDType(): DType = {
currentBatch.dtype
}
def getDType(): DType = dtype

/**
* Get the layout
* @return layout
*/
def getLayout(): String = {
currentBatch.layout
}
def getLayout(): (String, String) = (dataLayout, labelLayout)

// The name and shape of data provided by this iterator
override def provideData: ListMap[String, Shape] = _provideData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ import scala.collection.immutable.ListMap
class NDArrayIter(data: IndexedSeq[(String, NDArray)],
label: IndexedSeq[(String, NDArray)],
private val dataBatchSize: Int, shuffle: Boolean,
lastBatchHandle: String, dtype: DType, layout: String) extends DataIter {
lastBatchHandle: String,
dtype: DType, dataLayout: String, labelLayout: String) extends DataIter {

/**
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
Expand All @@ -61,10 +62,11 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label",
dType: DType = MX_REAL_TYPE, layout: String = "NCHW") {
dType: DType = MX_REAL_TYPE, dataLayout: String = "NCHW",
labelLayout: String = "N") {
this(IO.initData(data, allowEmpty = false, dataName),
IO.initData(label, allowEmpty = true, labelName),
dataBatchSize, shuffle, lastBatchHandle, dType, layout)
dataBatchSize, shuffle, lastBatchHandle, dType, dataLayout, labelLayout)
}

private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
Expand Down Expand Up @@ -111,8 +113,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],

private val (_provideDataDesc: IndexedSeq[DataDesc],
_provideLabelDesc: IndexedSeq[DataDesc]) = {
val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout))
val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, layout))
val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, dataLayout))
val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, labelLayout))
(pData, pLabel)
}

Expand Down Expand Up @@ -158,7 +160,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
if (hasNext) {
cursor += dataBatchSize
new DataBatch(getData(), getLabel(), getIndex(), getPad(),
dtype = getDType(), layout = getLayout())
dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2)
} else {
throw new NoSuchElementException
}
Expand Down Expand Up @@ -245,8 +247,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
* Get the layout
* @return layout
*/
def getLayout(): String = {
layout
def getLayout(): (String, String) = {
(dataLayout, labelLayout)
}

// The name and shape of data provided by this iterator
Expand Down Expand Up @@ -274,7 +276,8 @@ object NDArrayIter {
private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
private var dataBatchSize: Int = 1
private var lastBatchHandle: String = "pad"
private var layout: String = "NCHW"
private var dataLayout: String = "NCHW"
private var labelLayout: String = "N"
private var dtype: DType = Base.MX_REAL_TYPE

/**
Expand Down Expand Up @@ -331,11 +334,13 @@ object NDArrayIter {

/**
* Set the layout.
* @param layout The layout of the label, default is NCHW
* @param dataLayout The layout of the data, default is NCHW
* @param labelLayout The layout of the label, default is N
* @return this
*/
def setLayout(layout: String): Builder = {
this.layout = layout
def setLayout(dataLayout: String, labelLayout: String): Builder = {
this.dataLayout = dataLayout
this.labelLayout = labelLayout
this
}

Expand All @@ -344,7 +349,8 @@ object NDArrayIter {
* @return the built object.
*/
def build(): NDArrayIter = {
new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle, dtype, layout)
new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle,
dtype, dataLayout, labelLayout)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ class PrefetchingIter(
* Get the layout
* @return layout
*/
def getLayout(): String = {
currentBatch.layout
def getLayout(): (String, String) = {
(currentBatch.dataLayout, currentBatch.labelLayout)
}

// The name and shape of label provided by this iterator
Expand Down Expand Up @@ -224,7 +224,8 @@ class PrefetchingIter(
labels.toIndexedSeq.flatten,
nextBatch(0).index,
nextBatch(0).pad,
layout = nextBatch(0).layout,
dataLayout = nextBatch(0).dataLayout,
labelLayout = nextBatch(0).labelLayout,
dtype = nextBatch(0).dtype)
for (e <- dataTaken) e.release()
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ class ResizeIter(
* Get the layout
* @return layout
*/
def getLayout(): String = {
currentBatch.layout
def getLayout(): (String, String) = {
(currentBatch.dataLayout, currentBatch.labelLayout)
}

override def batchSize: Int = {
Expand Down
23 changes: 17 additions & 6 deletions scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
"shuffle" -> "1",
"flat" -> "1",
"silent" -> "0",
"seed" -> "10"
"seed" -> "10",
"dataLayout" -> "NT",
"labelLayout" -> "N"
)

val mnistPack = IO.MNISTPack(params)
Expand Down Expand Up @@ -99,7 +101,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
"data_shape" -> "(3,28,28)",
"batch_size" -> "100",
"preprocess_threads" -> "4",
"prefetch_buffer" -> "1"
"prefetch_buffer" -> "1",
"dataLayout" -> "NCHW",
"labelLayout" -> "N"
)
val imgRecIter = IO.ImageRecordIter(params)
val nBatch = 500
Expand Down Expand Up @@ -145,7 +149,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
"shuffle" -> "1",
"flat" -> "1",
"silent" -> "0",
"seed" -> "10"
"seed" -> "10",
"dataLayout" -> "NT",
"labelLayout" -> "N"
)

val mnistIter = IO.MNISTIter(params)
Expand Down Expand Up @@ -182,7 +188,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
"shuffle" -> "1",
"flat" -> "1",
"silent" -> "0",
"seed" -> "10"
"seed" -> "10",
"dataLayout" -> "NT",
"labelLayout" -> "N"
)

val mnistPack1 = IO.MNISTPack(params)
Expand Down Expand Up @@ -243,7 +251,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
val batchLabel = NDArray.ones(Shape(Array(128, 1)))

// test pad
val dataIter0 = new NDArrayIter(data, label, 128, false, "pad")
val dataIter0 = new NDArrayIter(data, label, 128, false, "pad",
dataLayout = "NTC", labelLayout = "NT")
var batchCount = 0
val nBatch0 = 8
while(dataIter0.hasNext) {
Expand All @@ -262,6 +271,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
.addData("data0", data(0)).addData("data1", data(1))
.addLabel("label", label(0))
.setBatchSize(128)
.setLayout("NTC", "NT")
.setLastBatchHandle("discard").build()
val nBatch1 = 7
batchCount = 0
Expand All @@ -277,7 +287,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(batchCount === nBatch1)

// test empty label (for prediction)
val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard")
val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard",
dataLayout = "NTC")
batchCount = 0
while(dataIter2.hasNext) {
val tBatch = dataIter2.next()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label",
dataLayout = "NCHW", labelLayout = "NCHW")

// symbols
var x = Symbol.Variable("data")
Expand Down Expand Up @@ -234,7 +235,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label",
dataLayout = "NCHW", labelLayout = "NCHW")

// symbols
var x = Symbol.Variable("data")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ object ExampleMultiTask {
new DataBatch(batch.data,
IndexedSeq(label, label),
batch.index,
batch.pad, dtype = batch.dtype, layout = batch.layout)
batch.pad, dtype = batch.dtype, dataLayout = batch.dataLayout,
labelLayout = batch.labelLayout)
} else {
throw new NoSuchElementException
}
Expand Down Expand Up @@ -129,7 +130,7 @@ object ExampleMultiTask {

override def getDType(): DType = this.dataIter.getDType()

override def getLayout(): String = this.dataIter.getLayout()
override def getLayout(): (String, String) = this.dataIter.getLayout()

// The name and shape of data provided by this iterator
override def provideData: ListMap[String, Shape] = this.dataIter.provideData
Expand Down
Loading

0 comments on commit a779124

Please sign in to comment.