Skip to content

Commit

Permalink
Now passing DType of Label downstream to Label's DataDesc object (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushghai authored and stephenrawls committed Feb 16, 2019
1 parent 4604584 commit 6068359
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
dataName: String = "data", labelName: String = "label") {
this(IO.initDataDesc(data, allowEmpty = false, dataName,
if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
IO.initDataDesc(label, allowEmpty = true, labelName,
if (label == null || label.isEmpty) MX_REAL_TYPE else label(0).dtype, Layout.UNDEFINED),
dataBatchSize, shuffle, lastBatchHandle)
}

Expand Down Expand Up @@ -175,7 +176,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
private def _padData(ndArray: NDArray): NDArray = {
val padNum = cursor + dataBatchSize - numData
val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size)
val newArray = NDArray.zeros(shape)
// The new NDArray has to be created such that it inherits dtype from the passed in array
val newArray = NDArray.zeros(shape, dtype = ndArray.dtype)
NDArrayCollector.auto().withScope {
val batch = ndArray.slice(cursor, numData)
val padding = ndArray.slice(0, padNum)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
val shape0 = Shape(Array(1000, 2, 2))
val data = IndexedSeq(NDArray.ones(shape0), NDArray.zeros(shape0))
val shape1 = Shape(Array(1000, 1))
val label = IndexedSeq(NDArray.ones(shape1))
val label = IndexedSeq(NDArray.ones(shape1, dtype = DType.Int32))
val batchData0 = NDArray.ones(Shape(Array(128, 2, 2)))
val batchData1 = NDArray.zeros(Shape(Array(128, 2, 2)))
val batchLabel = NDArray.ones(Shape(Array(128, 1)))
Expand All @@ -254,6 +254,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(tBatch.data(0).toArray === batchData0.toArray)
assert(tBatch.data(1).toArray === batchData1.toArray)
assert(tBatch.label(0).toArray === batchLabel.toArray)
assert(tBatch.label(0).dtype == DType.Int32)
}

assert(batchCount === nBatch0)
Expand Down

0 comments on commit 6068359

Please sign in to comment.