diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala index adeb33d34a95..cf55bc10d97e 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala @@ -24,6 +24,7 @@ import org.slf4j.LoggerFactory import scala.io import scala.collection.mutable.ListBuffer +import scala.collection.parallel.mutable.ParArray trait ClassifierBase { @@ -110,16 +111,21 @@ class Classifier(modelPathPrefix: String, : IndexedSeq[IndexedSeq[(String, Float)]] = { // considering only the first output - val predictResultND: NDArray = predictor.predictWithNDArray(input)(0) - - val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]() + // Copy NDArray to CPU to avoid frequent GPU to CPU copying + val predictResultND: NDArray = + predictor.predictWithNDArray(input)(0).asInContext(Context.cpu()) + // Parallel Execution with ParArray for better performance + val predictResultPar: ParArray[Array[Float]] = + new ParArray[Array[Float]](predictResultND.shape(0)) // iterating over the individual items(batch size is in axis 0) - for (i <- 0 until predictResultND.shape(0)) { + (0 until predictResultND.shape(0)).toVector.par.foreach( i => { val r = predictResultND.at(i) - predictResult += r.toArray + predictResultPar(i) = r.toArray r.dispose() - } + }) + + val predictResult = predictResultPar.toArray var result: ListBuffer[IndexedSeq[(String, Float)]] = ListBuffer.empty[IndexedSeq[(String, Float)]] diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala index a9b21f8c1dcd..78b237a4a9c6 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala @@ -19,6 +19,8 @@ package org.apache.mxnet.infer // scalastyle:off import java.awt.image.BufferedImage + +import scala.collection.parallel.mutable.ParArray // scalastyle:on import org.apache.mxnet.NDArray import org.apache.mxnet.DataDesc @@ -94,39 +96,39 @@ class ObjectDetector(modelPathPrefix: String, def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int]) : IndexedSeq[IndexedSeq[(String, Array[Float])]] = { - val predictResult = predictor.predictWithNDArray(input)(0) - var batchResult = ListBuffer[IndexedSeq[(String, Array[Float])]]() - for (i <- 0 until predictResult.shape(0)) { + // Copy NDArray to CPU to avoid frequent GPU to CPU copying + val predictResult = predictor.predictWithNDArray(input)(0).asInContext(Context.cpu()) + // Parallel Execution with ParArray for better performance + var batchResult = new ParArray[IndexedSeq[(String, Array[Float])]](predictResult.shape(0)) + (0 until predictResult.shape(0)).toArray.par.foreach( i => { val r = predictResult.at(i) - batchResult += sortAndReformat(r, topK) + batchResult(i) = sortAndReformat(r, topK) handler.execute(r.dispose()) - } + }) handler.execute(predictResult.dispose()) batchResult.toIndexedSeq } private[infer] def sortAndReformat(predictResultND: NDArray, topK: Option[Int]) : IndexedSeq[(String, Array[Float])] = { - val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]() - val accuracy: ListBuffer[Float] = ListBuffer[Float]() - // iterating over the all the predictions val length = predictResultND.shape(0) - for (i <- 0 until length) { + val predictResult = (0 until length).toArray.par.flatMap( i => { val r = predictResultND.at(i) val tempArr = r.toArray - if (tempArr(0) != -1.0) { - predictResult += tempArr - accuracy += tempArr(1) + val res = if (tempArr(0) != -1.0) { + Array[Array[Float]](tempArr) } else { // Ignore the minus 1 part + Array[Array[Float]]() } handler.execute(r.dispose()) - } + res + }).toArray var result = IndexedSeq[(String, Array[Float])]() if (topK.isDefined) { - var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2) + var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2) sortedIndices = sortedIndices.take(topK.get) // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax result = sortedIndices.map(idx @@ -136,7 +138,6 @@ class ObjectDetector(modelPathPrefix: String, result = predictResult.map(ele => (synset(ele(0).toInt), ele.takeRight(5))).toIndexedSeq } - result }