Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added Float64 in Classifier class
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushghai committed Dec 21, 2018
1 parent d1014b3 commit e564656
Showing 1 changed file with 39 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.mxnet.infer

import org.apache.mxnet.{Context, DataDesc, NDArray}
import org.apache.mxnet.{Base, Context, DataDesc, NDArray}
import java.io.File

import org.slf4j.LoggerFactory
Expand All @@ -35,8 +35,8 @@ trait ClassifierBase {
* elements to return. Default returns unsorted output.
* @return Indexed sequence of (Label, Score) tuples
*/
def classify(input: IndexedSeq[Array[Float]],
topK: Option[Int] = None): IndexedSeq[(String, Float)]
def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
topK: Option[Int] = None): IndexedSeq[(String, T)]

/**
* Takes a sequence of NDArrays and returns (Label, Score) tuples
Expand Down Expand Up @@ -83,11 +83,28 @@ class Classifier(modelPathPrefix: String,
* elements to return. Default returns unsorted output.
* @return Indexed sequence of (Label, Score) tuples
*/
override def classify(input: IndexedSeq[Array[Float]],
topK: Option[Int] = None): IndexedSeq[(String, Float)] = {
override def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
topK: Option[Int] = None): IndexedSeq[(String, T)] = {

// considering only the first output
val result = input(0)(0) match {
case d: Double => {
classifyWithDoubleImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK)
}
case _ => {
classifyWithFloatImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK)
}
}

result.asInstanceOf[IndexedSeq[(String, T)]]
}

private def classifyWithFloatImpl(input: IndexedSeq[Array[Float]], topK: Option[Int] = None)
: IndexedSeq[(String, Float)] = {

// considering only the first output
val predictResult = predictor.predict(input)(0)

var result: IndexedSeq[(String, Float)] = IndexedSeq.empty

if (topK.isDefined) {
Expand All @@ -99,6 +116,23 @@ class Classifier(modelPathPrefix: String,
result
}

private def classifyWithDoubleImpl(input: IndexedSeq[Array[Double]], topK: Option[Int] = None)
: IndexedSeq[(String, Double)] = {

// considering only the first output
val predictResult = predictor.predict(input)(0)

var result: IndexedSeq[(String, Double)] = IndexedSeq.empty

if (topK.isDefined) {
val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
result = sortedIndex.map(i => (synset(i), predictResult(i))).toIndexedSeq
} else {
result = synset.zip(predictResult).toIndexedSeq
}
result
}

/**
* Perform multiple classification operations on NDArrays.
* Also works with batched input.
Expand Down

0 comments on commit e564656

Please sign in to comment.