Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Scala/Java Predict API fix #14756 #14804

Merged
merged 3 commits into from
Apr 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,23 @@ abstract class BaseModule {

/**
* Run prediction and collect the outputs.
* @param evalData
* @param evalData dataIter to do the Inference
* @param numBatch Default is -1, indicating running all the batches in the data iterator.
* @param reset Default is `True`, indicating whether we should reset the data iter before start
* doing prediction.
* @return The return value will be a list `[out1, out2, out3]`.
* The concatenation process will be like
* {{{
* outputBatches = [
* [a1, a2, a3], // batch a
* [b1, b2, b3] // batch b
* ]
* result = [
* NDArray, // [a1, b1]
* NDArray, // [a2, b2]
* NDArray, // [a3, b3]
* ]
* }}}
* Where each element is concatenation of the outputs for all the mini-batches.
*/
def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
Expand All @@ -264,7 +276,8 @@ abstract class BaseModule {
s"in mini-batches (${out.size})." +
"Maybe bucketing is used?")
)
val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
val oBT = outputBatches.transpose
val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you also need to dispose oBT here ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This create a soft copy, the content inside contains same reference

outputBatches.foreach(_.foreach(_.dispose()))
concatenatedOutput
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
import org.apache.mxnet.io._

class ModuleSuite extends FunSuite with BeforeAndAfterAll {

class myModule(symbol : Symbol) extends Module (symbol) {
override def predictEveryBatch(evalData: DataIter,
numBatch: Int = 1, reset: Boolean = true):
IndexedSeq[IndexedSeq[NDArray]] = {
val data = IndexedSeq(
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 4))
)
List.fill(numBatch)(data).toIndexedSeq
}
}

test("predict") {
val sym = Symbol.Variable("data")
val mod = new myModule(sym)
val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
var output = mod.predict(dummyIter, 1)
require(output(0).shape == Shape(1, 10, 1))
require(output(1).shape == Shape(1, 10, 1))
require(output(2).shape == Shape(1, 10, 4))
output = mod.predict(dummyIter, 2)
require(output(0).shape == Shape(2, 10, 1))
require(output(1).shape == Shape(2, 10, 1))
require(output(2).shape == Shape(2, 10, 4))
}

test ("model dtype") {
val dType = DType.Float32
val dShape = Shape(3, 8, 7)
Expand Down