From 02f312567857bd2eaecaf41d44d0d44a7297ef54 Mon Sep 17 00:00:00 2001
From: Qing
Date: Thu, 25 Apr 2019 15:53:06 -0700
Subject: [PATCH 1/3] add fix in the code

---
 .../org/apache/mxnet/module/BaseModule.scala | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 3be8e060fd6f..39da3811f883 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -247,6 +247,18 @@ abstract class BaseModule {
 
   /**
    * Run prediction and collect the outputs.
+   * The concatenation process will be like
+   * {{{
+   * outputBatches = [
+   *   [a1, a2, a3], // batch a
+   *   [b1, b2, b3]  // batch b
+   * ]
+   * result = [
+   *   NDArray, // [a1, b1]
+   *   NDArray, // [a2, b2]
+   *   NDArray, // [a3, b3]
+   * ]
+   * }}}
    * @param evalData
    * @param numBatch Default is -1, indicating running all the batches in the data iterator.
    * @param reset Default is `True`, indicating whether we should reset the data iter before start
@@ -264,7 +276,8 @@
         s"in mini-batches (${out.size})." +
         "Maybe bucketing is used?")
     )
-    val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
+    val oBT = outputBatches.transpose
+    val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
     outputBatches.foreach(_.foreach(_.dispose()))
     concatenatedOutput
   }

From e7ac2fb990db39c9ff725d256eb9f8809ed17d91 Mon Sep 17 00:00:00 2001
From: Qing
Date: Fri, 26 Apr 2019 15:27:41 -0700
Subject: [PATCH 2/3] add unit test

---
 .../scala/org/apache/mxnet/ModuleSuite.scala | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 3e753a18d247..5aed01bde693 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
 import org.apache.mxnet.io._
 
 class ModuleSuite extends FunSuite with BeforeAndAfterAll {
+
+  class myModule(symbol : Symbol) extends Module (symbol) {
+    override def predictEveryBatch(evalData: DataIter,
+                                   numBatch: Int = 1, reset: Boolean = true):
+    IndexedSeq[IndexedSeq[NDArray]] = {
+      val data = IndexedSeq(
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 4))
+      )
+      List.fill(numBatch)(data).toIndexedSeq
+    }
+  }
+
+  test("predict") {
+    val sym = Symbol.Variable("data")
+    val mod = new myModule(sym)
+    val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
+    var output = mod.predict(dummyIter, 1)
+    require(output(0).shape == Shape(1, 10, 1))
+    require(output(1).shape == Shape(1, 10, 1))
+    require(output(2).shape == Shape(1, 10, 4))
+    output = mod.predict(dummyIter, 2)
+    require(output(0).shape == Shape(2, 10, 1))
+    require(output(1).shape == Shape(2, 10, 1))
+    require(output(2).shape == Shape(2, 10, 4))
+  }
+
   test ("model dtype") {
     val dType = DType.Float32
     val dShape = Shape(3, 8, 7)

From 93256ef40e8ceb99f581f8797c191e36fa8a2523 Mon Sep 17 00:00:00 2001
From: Qing
Date: Fri, 26 Apr 2019 15:33:01 -0700
Subject: [PATCH 3/3] update comments

---
 .../org/apache/mxnet/module/BaseModule.scala | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 39da3811f883..7fbdae5b3e21 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -247,23 +247,23 @@ abstract class BaseModule {
 
   /**
    * Run prediction and collect the outputs.
-   * The concatenation process will be like
-   * {{{
-   * outputBatches = [
-   *   [a1, a2, a3], // batch a
-   *   [b1, b2, b3]  // batch b
-   * ]
-   * result = [
-   *   NDArray, // [a1, b1]
-   *   NDArray, // [a2, b2]
-   *   NDArray, // [a3, b3]
-   * ]
-   * }}}
-   * @param evalData
+   * @param evalData dataIter to do the Inference
    * @param numBatch Default is -1, indicating running all the batches in the data iterator.
    * @param reset Default is `True`, indicating whether we should reset the data iter before start
    * doing prediction.
    * @return The return value will be a list `[out1, out2, out3]`.
+   * The concatenation process will be like
+   * {{{
+   * outputBatches = [
+   *   [a1, a2, a3], // batch a
+   *   [b1, b2, b3]  // batch b
+   * ]
+   * result = [
+   *   NDArray, // [a1, b1]
+   *   NDArray, // [a2, b2]
+   *   NDArray, // [a3, b3]
+   * ]
+   * }}}
    * Where each element is concatenation of the outputs for all the mini-batches.
    */
   def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
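
Note (reviewer sketch, not part of the patch series): the snippet below illustrates the transpose-then-concatenate behaviour that PATCH 1/3 introduces in predict. It is a minimal, hedged example: the outputBatches value and its shapes are made up here (they mirror the shapes asserted in the PATCH 2/3 unit test); only NDArray.ones, NDArray.concatenate, Shape, and the standard-library transpose are assumed to be available in the MXNet Scala API.

    import org.apache.mxnet.{NDArray, Shape}

    // outputBatches(i)(j) = output j produced for mini-batch i,
    // i.e. [[a1, a2, a3], [b1, b2, b3]] in the doc comment of the patch.
    val outputBatches: IndexedSeq[IndexedSeq[NDArray]] = IndexedSeq(
      IndexedSeq(NDArray.ones(Shape(1, 10, 1)),
                 NDArray.ones(Shape(1, 10, 1)),
                 NDArray.ones(Shape(1, 10, 4))), // batch a
      IndexedSeq(NDArray.ones(Shape(1, 10, 1)),
                 NDArray.ones(Shape(1, 10, 1)),
                 NDArray.ones(Shape(1, 10, 4)))  // batch b
    )

    // Before the fix, concatenation ran over each batch's own outputs
    // ([a1, a2, a3]), gluing unrelated outputs together. Transposing first
    // regroups the same output across batches ([a1, b1], [a2, b2], [a3, b3]),
    // so each concatenated NDArray stacks one output over all batches along axis 0.
    val grouped: IndexedSeq[IndexedSeq[NDArray]] = outputBatches.transpose
    val result: IndexedSeq[NDArray] = grouped.map(out => NDArray.concatenate(out))
    // result shapes: (2, 10, 1), (2, 10, 1), (2, 10, 4),
    // matching the assertions in the PATCH 2/3 unit test.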