Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixes #14181, validate model output shape for ObjectDetector.
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Feb 20, 2019
1 parent f9c436b commit f1eea20
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class ImageClassifier(modelPathPrefix: String,
protected[infer] val height = inputShape(inputLayout.indexOf('H'))
protected[infer] val width = inputShape(inputLayout.indexOf('W'))

def outputShapes: IndexedSeq[(String, Shape)] = {
predictor.outputShapes
}

/**
* To classify the image according to the provided model
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ package org.apache.mxnet.infer
// scalastyle:off
import java.awt.image.BufferedImage

import org.apache.mxnet.Shape

import scala.collection.parallel.mutable.ParArray
// scalastyle:on
import org.apache.mxnet.NDArray
import org.apache.mxnet.DataDesc
import org.apache.mxnet.Context
import scala.collection.mutable.ListBuffer

/**
* The ObjectDetector class helps to run ObjectDetection tasks where the goal
Expand Down Expand Up @@ -174,7 +175,25 @@ class ObjectDetector(modelPathPrefix: String,
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)):
ImageClassifier = {
new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}
val imageClassifier: ImageClassifier =
new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)

val shapes: IndexedSeq[(String, Shape)] = imageClassifier.outputShapes
if (shapes.length != inputDescriptors.length) {
throw new IllegalStateException(s"Invalid output shapes, expected:" +
s" $inputDescriptors.length, actual: $shapes.length.")
}
shapes.map(_._2).foreach(shape => {
if (shape.length < 3) {
throw new IllegalArgumentException("Invalid output shapes, the model doesn't"
+ " support object detection.")
}
if (shape.get(2) < 6) {
throw new IllegalArgumentException("Invalid output shapes, the model doesn't"
+ " support object detection with bounding box.")
}
})

imageClassifier
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ private[infer] trait PredictBase {
*/
def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray]

/**
* Get model output shapes.
* @return model output shapes.
*/
def outputShapes: IndexedSeq[(String, Shape)]
}

/**
Expand Down Expand Up @@ -122,6 +127,10 @@ class Predictor(modelPathPrefix: String,

protected[infer] val mod = loadModule()

override def outputShapes: IndexedSeq[(String, Shape)] = {
mod.outputShapes
}

/**
* Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference
* The array will be reshaped based on the input descriptors.
Expand Down

0 comments on commit f1eea20

Please sign in to comment.