Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve NDArrayIter to have Builder and ability to specifying names #3

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import scala.collection.immutable.ListMap
/**
* NDArrayIter object in mxnet. Taking NDArray to get dataiter.
*
* @param data NDArrayIter supports single or multiple data and label.
* @param data Specify the data as well as the name.
* NDArrayIter supports single or multiple data and label.
* @param label Same as data, but is not fed to the model during testing.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
Expand All @@ -38,15 +39,35 @@ import scala.collection.immutable.ListMap
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
private val dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
class NDArrayIter(data: IndexedSeq[(String, NDArray)],
label: IndexedSeq[(String, NDArray)],
private val dataBatchSize: Int, shuffle: Boolean,
lastBatchHandle: String) extends DataIter {

/**
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
* @param label Same as data, but is not fed to the model during testing.
* Label names will be label_0, label_1, ..., etc.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
* @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
*
* This iterator will pad, discard or roll over the last batch if
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") {
this(IO.initData(data, allowEmpty = false, dataName),
IO.initData(label, allowEmpty = true, labelName),
dataBatchSize, shuffle, lastBatchHandle)
}

private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])

private val (_dataList: IndexedSeq[NDArray],
_labelList: IndexedSeq[NDArray]) = {
val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = {
// data should not be null and size > 0
require(data != null && data.size > 0,
"data should not be null and data.size should not be zero")
Expand All @@ -55,17 +76,17 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
"label should not be null. Use IndexedSeq.empty if there are no labels")

// shuffle is not supported currently
require(shuffle == false, "shuffle is not supported currently")
require(!shuffle, "shuffle is not supported currently")

// discard final part if lastBatchHandle equals discard
if (lastBatchHandle.equals("discard")) {
val dataSize = data(0).shape(0)
val dataSize = data(0)._2.shape(0)
require(dataBatchSize <= dataSize,
"batch_size need to be smaller than data size when not padding.")
val keepSize = dataSize - dataSize % dataBatchSize
val dataList = data.map(ndArray => {ndArray.slice(0, keepSize)})
val dataList = data.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
if (!label.isEmpty) {
val labelList = label.map(ndArray => {ndArray.slice(0, keepSize)})
val labelList = label.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
(dataList, labelList)
} else {
(dataList, label)
Expand All @@ -75,13 +96,9 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
}
}


val initData: IndexedSeq[(String, NDArray)] = IO.initData(_dataList, false, dataName)
val initLabel: IndexedSeq[(String, NDArray)] = IO.initData(_labelList, true, labelName)
val numData = _dataList(0).shape(0)
val numSource = initData.size
var cursor = -dataBatchSize

val numData = initData(0)._2.shape(0)
val numSource: MXUint = initData.size
private var cursor = -dataBatchSize

private val (_provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape]) = {
Expand Down Expand Up @@ -112,8 +129,8 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* reset the iterator
*/
override def reset(): Unit = {
if (lastBatchHandle.equals("roll_over") && cursor>numData) {
cursor = -dataBatchSize + (cursor%numData)%dataBatchSize
if (lastBatchHandle.equals("roll_over") && cursor > numData) {
cursor = -dataBatchSize + (cursor%numData) % dataBatchSize
} else {
cursor = -dataBatchSize
}
Expand Down Expand Up @@ -154,16 +171,16 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
newArray
}

private def _getData(data: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
require(cursor < numData, "DataIter needs reset.")
if (data == null) {
null
} else {
if (cursor + dataBatchSize <= numData) {
data.map(ndArray => {ndArray.slice(cursor, cursor + dataBatchSize)}).toIndexedSeq
data.map { case (_, ndArray) => ndArray.slice(cursor, cursor + dataBatchSize) }
} else {
// padding
data.map(_padData).toIndexedSeq
data.map { case (_, ndArray) => _padData(ndArray) }
}
}
}
Expand All @@ -173,23 +190,23 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* @return the data of current batch
*/
override def getData(): IndexedSeq[NDArray] = {
_getData(_dataList)
_getData(initData)
}

/**
* Get label of current batch
* @return the label of current batch
*/
override def getLabel(): IndexedSeq[NDArray] = {
_getData(_labelList)
_getData(initLabel)
}

/**
* the index of current batch
* @return
*/
override def getIndex(): IndexedSeq[Long] = {
(cursor.toLong to (cursor + dataBatchSize).toLong).toIndexedSeq
cursor.toLong to (cursor + dataBatchSize).toLong
}

/**
Expand All @@ -213,3 +230,66 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index

override def batchSize: Int = dataBatchSize
}

object NDArrayIter {

/**
* Builder class for NDArrayIter.
*/
class Builder() {
private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
private var dataBatchSize: Int = 1
private var lastBatchHandle: String = "pad"

/**
* Add one data input with its name.
* @param name Data name.
* @param data Data nd-array.
* @return The builder object itself.
*/
def addData(name: String, data: NDArray): Builder = {
this.data = this.data ++ IndexedSeq((name, data))
this
}

/**
* Add one label input with its name.
* @param name Label name.
* @param label Label nd-array.
* @return The builder object itself.
*/
def addLabel(name: String, label: NDArray): Builder = {
this.label = this.label ++ IndexedSeq((name, label))
this
}

/**
* Set the batch size of the iterator.
* @param batchSize batch size.
* @return The builder object itself.
*/
def setBatchSize(batchSize: Int): Builder = {
this.dataBatchSize = batchSize
this
}

/**
* How to handle the last batch.
* @param lastBatchHandle Can be "pad", "discard" or "roll_over".
* @return The builder object itself.
*/
def setLastBatchHandle(lastBatchHandle: String): Builder = {
this.lastBatchHandle = lastBatchHandle
this
}

/**
* Build the NDArrayIter object.
* @return the built object.
*/
def build(): NDArrayIter = {
new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.sys.process._

class IOSuite extends FunSuite with BeforeAndAfterAll {

private var tu = new TestUtil
private val tu = new TestUtil

test("test MNISTIter & MNISTPack") {
// get data
Expand Down Expand Up @@ -258,7 +258,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(batchCount === nBatch0)

// test discard
val dataIter1 = new NDArrayIter(data, label, 128, false, "discard")
val dataIter1 = new NDArrayIter.Builder()
.addData("data0", data(0)).addData("data1", data(1))
.addLabel("label", label(0))
.setBatchSize(128)
.setLastBatchHandle("discard").build()
val nBatch1 = 7
batchCount = 0
while(dataIter1.hasNext) {
Expand Down