diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index e690abba0d13..b205bbe47abb 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -63,7 +63,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], dataName: String = "data", labelName: String = "label") { this(IO.initDataDesc(data, allowEmpty = false, dataName, if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED), - IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED), + IO.initDataDesc(label, allowEmpty = true, labelName, + if (label == null || label.isEmpty) MX_REAL_TYPE else label(0).dtype, Layout.UNDEFINED), dataBatchSize, shuffle, lastBatchHandle) } @@ -175,7 +176,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], private def _padData(ndArray: NDArray): NDArray = { val padNum = cursor + dataBatchSize - numData val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size) - val newArray = NDArray.zeros(shape) + // The new NDArray has to be created such that it inherits dtype from the passed in array + val newArray = NDArray.zeros(shape, dtype = ndArray.dtype) NDArrayCollector.auto().withScope { val batch = ndArray.slice(cursor, numData) val padding = ndArray.slice(0, padNum) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index d3969b0ce77d..698a2b53a9fa 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -237,7 +237,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val shape0 = Shape(Array(1000, 2, 2)) val data = IndexedSeq(NDArray.ones(shape0), NDArray.zeros(shape0)) val shape1 = Shape(Array(1000, 1)) - val label = IndexedSeq(NDArray.ones(shape1)) + val label = IndexedSeq(NDArray.ones(shape1, dtype = DType.Int32)) val batchData0 = NDArray.ones(Shape(Array(128, 2, 2))) val batchData1 = NDArray.zeros(Shape(Array(128, 2, 2))) val batchLabel = NDArray.ones(Shape(Array(128, 1))) @@ -254,6 +254,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(tBatch.data(0).toArray === batchData0.toArray) assert(tBatch.data(1).toArray === batchData1.toArray) assert(tBatch.label(0).toArray === batchLabel.toArray) + assert(tBatch.label(0).dtype == DType.Int32) } assert(batchCount === nBatch0)