Skip to content

Commit

Permalink
Scala/Java Predict API fix apache#14756 (apache#14804)
Browse files Browse the repository at this point in the history
* add fix in the code

* add unit test

* update comments
  • Loading branch information
lanking520 authored and haohuw committed Jun 23, 2019
1 parent de374b0 commit f714b21
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,23 @@ abstract class BaseModule {

/**
* Run prediction and collect the outputs.
* @param evalData
* @param evalData dataIter to do the Inference
* @param numBatch Default is -1, indicating running all the batches in the data iterator.
* @param reset Default is `True`, indicating whether we should reset the data iter before start
* doing prediction.
* @return The return value will be a list `[out1, out2, out3]`.
* The concatenation process will be like
* {{{
* outputBatches = [
* [a1, a2, a3], // batch a
* [b1, b2, b3] // batch b
* ]
* result = [
* NDArray, // [a1, b1]
* NDArray, // [a2, b2]
* NDArray, // [a3, b3]
* ]
* }}}
* Where each element is concatenation of the outputs for all the mini-batches.
*/
def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
Expand All @@ -264,7 +276,8 @@ abstract class BaseModule {
s"in mini-batches (${out.size})." +
"Maybe bucketing is used?")
)
val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
val oBT = outputBatches.transpose
val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
outputBatches.foreach(_.foreach(_.dispose()))
concatenatedOutput
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
import org.apache.mxnet.io._

class ModuleSuite extends FunSuite with BeforeAndAfterAll {

class myModule(symbol : Symbol) extends Module (symbol) {
override def predictEveryBatch(evalData: DataIter,
numBatch: Int = 1, reset: Boolean = true):
IndexedSeq[IndexedSeq[NDArray]] = {
val data = IndexedSeq(
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 4))
)
List.fill(numBatch)(data).toIndexedSeq
}
}

test("predict") {
val sym = Symbol.Variable("data")
val mod = new myModule(sym)
val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
var output = mod.predict(dummyIter, 1)
require(output(0).shape == Shape(1, 10, 1))
require(output(1).shape == Shape(1, 10, 1))
require(output(2).shape == Shape(1, 10, 4))
output = mod.predict(dummyIter, 2)
require(output(0).shape == Shape(2, 10, 1))
require(output(1).shape == Shape(2, 10, 1))
require(output(2).shape == Shape(2, 10, 4))
}

test ("model dtype") {
val dType = DType.Float32
val dShape = Shape(3, 8, 7)
Expand Down

0 comments on commit f714b21

Please sign in to comment.