From f1eea20f43682c5b7e4aab628e846d358b225521 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 20 Feb 2019 14:37:51 -0800 Subject: [PATCH] Fixes #14181, validate model output shape for ObjectDetector. --- .../apache/mxnet/infer/ImageClassifier.scala | 4 +++ .../apache/mxnet/infer/ObjectDetector.scala | 25 ++++++++++++++++--- .../org/apache/mxnet/infer/Predictor.scala | 9 +++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala index 3c80f9226399..7c72b90dde74 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala @@ -66,6 +66,10 @@ class ImageClassifier(modelPathPrefix: String, protected[infer] val height = inputShape(inputLayout.indexOf('H')) protected[infer] val width = inputShape(inputLayout.indexOf('W')) + def outputShapes: IndexedSeq[(String, Shape)] = { + predictor.outputShapes + } + /** * To classify the image according to the provided model * diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala index 7146156d7cc5..e29f068d5558 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala @@ -20,12 +20,13 @@ package org.apache.mxnet.infer // scalastyle:off import java.awt.image.BufferedImage +import org.apache.mxnet.Shape + import scala.collection.parallel.mutable.ParArray // scalastyle:on import org.apache.mxnet.NDArray import org.apache.mxnet.DataDesc import org.apache.mxnet.Context -import scala.collection.mutable.ListBuffer /** * The ObjectDetector class helps to run ObjectDetection tasks where the goal @@ -174,7 +175,25 @@ class ObjectDetector(modelPathPrefix: String, contexts: Array[Context] = Context.cpu(), epoch: Option[Int] = Some(0)): ImageClassifier = { - new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch) - } + val imageClassifier: ImageClassifier = + new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch) + + val shapes: IndexedSeq[(String, Shape)] = imageClassifier.outputShapes + if (shapes.length != inputDescriptors.length) { + throw new IllegalStateException(s"Invalid output shapes, expected:" + + s" $inputDescriptors.length, actual: $shapes.length.") + } + shapes.map(_._2).foreach(shape => { + if (shape.length < 3) { + throw new IllegalArgumentException("Invalid output shapes, the model doesn't" + + " support object detection.") + } + if (shape.get(2) < 6) { + throw new IllegalArgumentException("Invalid output shapes, the model doesn't" + + " support object detection with bounding box.") + } + }) + imageClassifier + } } diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala index 67692a316cc4..b3c891d97b32 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala @@ -56,6 +56,11 @@ private[infer] trait PredictBase { */ def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray] + /** + * Get model output shapes. + * @return model output shapes. + */ + def outputShapes: IndexedSeq[(String, Shape)] } /** @@ -122,6 +127,10 @@ class Predictor(modelPathPrefix: String, protected[infer] val mod = loadModule() + override def outputShapes: IndexedSeq[(String, Shape)] = { + mod.outputShapes + } + /** * Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference * The array will be reshaped based on the input descriptors.