Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Reduce post processing time
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Dec 4, 2018
1 parent f2dcd7c commit 02c755b
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ class ObjectDetector(modelPathPrefix: String,
def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int])
: IndexedSeq[IndexedSeq[(String, Array[Float])]] = {

val predictResult = predictor.predictWithNDArray(input)(0)
val predictResult = predictor.predictWithNDArray(input)(0).asInContext(Context.cpu())
var batchResult = ListBuffer[IndexedSeq[(String, Array[Float])]]()
for (i <- 0 until predictResult.shape(0)) {
(0 until predictResult.shape(0)).toVector.par.foreach( i => {
val r = predictResult.at(i)
batchResult += sortAndReformat(r, topK)
handler.execute(r.dispose())
}
})
handler.execute(predictResult.dispose())
batchResult.toIndexedSeq
}
Expand All @@ -113,7 +113,7 @@ class ObjectDetector(modelPathPrefix: String,
// iterating over the all the predictions
val length = predictResultND.shape(0)

for (i <- 0 until length) {
(0 until length).toVector.par.foreach( i => {
val r = predictResultND.at(i)
val tempArr = r.toArray
if (tempArr(0) != -1.0) {
Expand All @@ -123,7 +123,7 @@ class ObjectDetector(modelPathPrefix: String,
// Ignore the minus 1 part
}
handler.execute(r.dispose())
}
})
var result = IndexedSeq[(String, Array[Float])]()
if (topK.isDefined) {
var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2)
Expand Down

0 comments on commit 02c755b

Please sign in to comment.