diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 3be8e060fd6f..7fbdae5b3e21 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -247,11 +247,23 @@ abstract class BaseModule {
   /**
    * Run prediction and collect the outputs.
-   * @param evalData
+   * @param evalData dataIter to do the Inference
    * @param numBatch Default is -1, indicating running all the batches in the data iterator.
    * @param reset Default is `True`, indicating whether we should reset the data iter before start
    *              doing prediction.
    * @return The return value will be a list `[out1, out2, out3]`.
+   *         The concatenation process will be like
+   *         {{{
+   *         outputBatches = [
+   *           [a1, a2, a3], // batch a
+   *           [b1, b2, b3]  // batch b
+   *         ]
+   *         result = [
+   *           NDArray, // [a1, b1]
+   *           NDArray, // [a2, b2]
+   *           NDArray, // [a3, b3]
+   *         ]
+   *         }}}
    *         Where each element is concatenation of the outputs for all the mini-batches.
    */
   def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
     : IndexedSeq[NDArray] = {
@@ -264,7 +276,8 @@ abstract class BaseModule {
         s"in mini-batches (${out.size})." +
         "Maybe bucketing is used?")
     )
-    val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
+    val oBT = outputBatches.transpose
+    val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
     outputBatches.foreach(_.foreach(_.dispose()))
     concatenatedOutput
   }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 3e753a18d247..5aed01bde693 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
 import org.apache.mxnet.io._
 
 class ModuleSuite extends FunSuite with BeforeAndAfterAll {
+
+  class myModule(symbol : Symbol) extends Module (symbol) {
+    override def predictEveryBatch(evalData: DataIter,
+                                   numBatch: Int = 1, reset: Boolean = true):
+    IndexedSeq[IndexedSeq[NDArray]] = {
+      val data = IndexedSeq(
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 4))
+      )
+      List.fill(numBatch)(data).toIndexedSeq
+    }
+  }
+
+  test("predict") {
+    val sym = Symbol.Variable("data")
+    val mod = new myModule(sym)
+    val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
+    var output = mod.predict(dummyIter, 1)
+    require(output(0).shape == Shape(1, 10, 1))
+    require(output(1).shape == Shape(1, 10, 1))
+    require(output(2).shape == Shape(1, 10, 4))
+    output = mod.predict(dummyIter, 2)
+    require(output(0).shape == Shape(2, 10, 1))
+    require(output(1).shape == Shape(2, 10, 1))
+    require(output(2).shape == Shape(2, 10, 4))
+  }
+
   test ("model dtype") {
     val dType = DType.Float32
     val dShape = Shape(3, 8, 7)