diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala index 3c80f9226399..99c0432d79f2 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala @@ -66,6 +66,8 @@ class ImageClassifier(modelPathPrefix: String, protected[infer] val height = inputShape(inputLayout.indexOf('H')) protected[infer] val width = inputShape(inputLayout.indexOf('W')) + def outputShapes: IndexedSeq[(String, Shape)] = predictor.outputShapes + /** * To classify the image according to the provided model * diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala index 7146156d7cc5..e29f068d5558 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala @@ -20,12 +20,13 @@ package org.apache.mxnet.infer // scalastyle:off import java.awt.image.BufferedImage +import org.apache.mxnet.Shape + import scala.collection.parallel.mutable.ParArray // scalastyle:on import org.apache.mxnet.NDArray import org.apache.mxnet.DataDesc import org.apache.mxnet.Context -import scala.collection.mutable.ListBuffer /** * The ObjectDetector class helps to run ObjectDetection tasks where the goal @@ -174,7 +175,25 @@ class ObjectDetector(modelPathPrefix: String, contexts: Array[Context] = Context.cpu(), epoch: Option[Int] = Some(0)): ImageClassifier = { - new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch) - } + val imageClassifier: ImageClassifier = + new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch) + + val shapes: IndexedSeq[(String, Shape)] = imageClassifier.outputShapes + if (shapes.length != inputDescriptors.length) { + throw new IllegalStateException(s"Invalid output shapes, expected:" + + s" $inputDescriptors.length, actual: $shapes.length.") + } + shapes.map(_._2).foreach(shape => { + if (shape.length < 3) { + throw new IllegalArgumentException("Invalid output shapes, the model doesn't" + + " support object detection.") + } + if (shape.get(2) < 6) { + throw new IllegalArgumentException("Invalid output shapes, the model doesn't" + + " support object detection with bounding box.") + } + }) + imageClassifier + } } diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala index 67692a316cc4..66284c81bd2e 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala @@ -56,6 +56,11 @@ private[infer] trait PredictBase { */ def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray] + /** + * Get model output shapes. + * @return model output shapes. + */ + def outputShapes: IndexedSeq[(String, Shape)] } /** @@ -122,6 +127,8 @@ class Predictor(modelPathPrefix: String, protected[infer] val mod = loadModule() + override def outputShapes: IndexedSeq[(String, Shape)] = mod.outputShapes + /** * Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference * The array will be reshaped based on the input descriptors.