From 9f97ec40b1caf330721d3639ce2e808e1f4425e1 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 4 Dec 2018 15:48:03 -0800 Subject: [PATCH] fix the CI --- .../org/apache/mxnet/infer/Classifier.scala | 8 ++++-- .../apache/mxnet/infer/ObjectDetector.scala | 25 +++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala index 9c5a1d734699..8351cf938695 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala @@ -24,6 +24,7 @@ import org.slf4j.LoggerFactory import scala.io import scala.collection.mutable.ListBuffer +import scala.collection.parallel.mutable.ParArray trait ClassifierBase { @@ -113,15 +114,18 @@ class Classifier(modelPathPrefix: String, val predictResultND: NDArray = predictor.predictWithNDArray(input)(0).asInContext(Context.cpu()) - val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]() + val predictResultPar: ParArray[Array[Float]] = + new ParArray[Array[Float]](predictResultND.shape(0)) // iterating over the individual items(batch size is in axis 0) (0 until predictResultND.shape(0)).toVector.par.foreach( i => { val r = predictResultND.at(i) - predictResult += r.toArray + predictResultPar(i) = r.toArray r.dispose() }) + val predictResult = predictResultPar.toArray + var result: ListBuffer[IndexedSeq[(String, Float)]] = ListBuffer.empty[IndexedSeq[(String, Float)]] diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala index dc4149188522..244f0286f2fd 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala @@ -19,6 +19,8 @@ package org.apache.mxnet.infer // scalastyle:off import java.awt.image.BufferedImage + +import scala.collection.parallel.mutable.ParArray // scalastyle:on import org.apache.mxnet.NDArray import org.apache.mxnet.DataDesc @@ -95,10 +97,10 @@ class ObjectDetector(modelPathPrefix: String, : IndexedSeq[IndexedSeq[(String, Array[Float])]] = { val predictResult = predictor.predictWithNDArray(input)(0).asInContext(Context.cpu()) - var batchResult = ListBuffer[IndexedSeq[(String, Array[Float])]]() - (0 until predictResult.shape(0)).toVector.par.foreach( i => { + var batchResult = new ParArray[IndexedSeq[(String, Array[Float])]](predictResult.shape(0)) + (0 until predictResult.shape(0)).toArray.par.foreach( i => { val r = predictResult.at(i) - batchResult += sortAndReformat(r, topK) + batchResult(i) = sortAndReformat(r, topK) handler.execute(r.dispose()) }) handler.execute(predictResult.dispose()) @@ -107,26 +109,24 @@ class ObjectDetector(modelPathPrefix: String, private[infer] def sortAndReformat(predictResultND: NDArray, topK: Option[Int]) : IndexedSeq[(String, Array[Float])] = { - val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]() - val accuracy: ListBuffer[Float] = ListBuffer[Float]() - // iterating over the all the predictions val length = predictResultND.shape(0) - (0 until length).toVector.par.foreach( i => { + val predictResult = (0 until length).toArray.par.flatMap( i => { val r = predictResultND.at(i) val tempArr = r.toArray - if (tempArr(0) != -1.0) { - predictResult += tempArr - accuracy += tempArr(1) + val res = if (tempArr(0) != -1.0) { + Array[Array[Float]](tempArr) } else { // Ignore the minus 1 part + Array[Array[Float]]() } handler.execute(r.dispose()) - }) + res + }).toArray var result = IndexedSeq[(String, Array[Float])]() if (topK.isDefined) { - var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2) + var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2) sortedIndices = sortedIndices.take(topK.get) // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax result = sortedIndices.map(idx @@ -136,7 +136,6 @@ class ObjectDetector(modelPathPrefix: String, result = predictResult.map(ele => (synset(ele(0).toInt), ele.takeRight(5))).toIndexedSeq } - result }