From b191aee4b7b897c1006243c2d18608e450fee3e6 Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Thu, 24 Jan 2019 11:36:55 -0800 Subject: [PATCH] [MXNET-1293] Adding Iterables instead of List to method signature for infer APIs in Java (#13977) * Added Iterables as input type instead of List in Predictor for Java * Added Iterables to ObjectDetector API * Added tests for Predictor API * Added tests for ObjectDetector --- .../mxnet/infer/javaapi/ObjectDetector.scala | 10 +++---- .../mxnet/infer/javaapi/Predictor.scala | 12 ++++---- .../infer/javaapi/ObjectDetectorTest.java | 25 ++++++++++++++++ .../mxnet/infer/javaapi/PredictorTest.java | 29 +++++++++++++++++-- 4 files changed, 62 insertions(+), 14 deletions(-) diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala index 3014f8d976da..05334e49a356 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala @@ -44,8 +44,8 @@ import scala.language.implicitConversions */ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){ - def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts: - java.util.List[Context], epoch: Int) + def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], contexts: + java.lang.Iterable[Context], epoch: Int) = this { val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator) .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc} @@ -79,7 +79,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj * @return List of list of tuples of * (class, [probability, xmin, ymin, xmax, ymax]) */ - def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int): + def objectDetectWithNDArray(input: java.lang.Iterable[NDArray], topK: Int): java.util.List[java.util.List[ObjectDetectorOutput]] = { val ret = objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK)) (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava @@ -92,7 +92,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj * @param topK Number of result elements to return, sorted by probability * @return List of list of tuples of (class, probability) */ - def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int): + def imageBatchObjectDetect(inputBatch: java.lang.Iterable[BufferedImage], topK: Int): java.util.List[java.util.List[ObjectDetectorOutput]] = { val ret = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK)) (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava @@ -122,7 +122,7 @@ object ObjectDetector { org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape) } - def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = { + def loadInputBatch(inputImagePaths: java.lang.Iterable[String]): java.util.List[BufferedImage] = { org.apache.mxnet.infer.ImageClassifier .loadInputBatch(inputImagePaths.asScala.toList).toList.asJava } diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala index 146fe93105e4..6c0871fae51b 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala @@ -40,8 +40,8 @@ import scala.collection.JavaConverters._ // JavaDoc description of class to be updated in https://issues.apache.org/jira/browse/MXNET-1178 class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor){ - def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], - contexts: java.util.List[Context], epoch: Int) + def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], + contexts: java.lang.Iterable[Context], epoch: Int) = this { val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator) .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc} @@ -97,10 +97,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor) } /** - * Takes input as List of one dimensional arrays and creates the NDArray needed for inference + * Takes input as List of one dimensional iterables and creates the NDArray needed for inference * The array will be reshaped based on the input descriptors. * - * @param input: A List of a one-dimensional array. + * @param input: A List of a one-dimensional iterables of DType Float. An extra List is needed for when the model has more than one input. * @return Indexed sequence array of outputs */ @@ -118,10 +118,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor) * This method is useful when the input is a batch of data * Note: User is responsible for managing allocation/deallocation of input/output NDArrays. * - * @param input List of NDArrays + * @param input Iterable of NDArrays * @return Output of predictions as NDArrays */ - def predictWithNDArray(input: java.util.List[NDArray]): + def predictWithNDArray(input: java.lang.Iterable[NDArray]): java.util.List[NDArray] = { val ret = predictor.predictWithNDArray(convert(JavaConverters .asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq)) diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java index a5e64911d141..3219b5aac8e1 100644 --- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java @@ -29,7 +29,9 @@ import java.awt.image.BufferedImage; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; public class ObjectDetectorTest { @@ -92,6 +94,17 @@ public void testObjectDetectorWithBatchImage() { Assert.assertEquals(expectedResult, actualResult); } + @Test + public void testObjectDetectorWithIterableOfBatchImage() { + + Set batchImage = new HashSet<>(); + batchImage.add(inputImage); + Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult); + List> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK); + Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK); + Assert.assertEquals(expectedResult, actualResult); + } + @Test public void testObjectDetectorWithNDArrayInput() { @@ -103,4 +116,16 @@ public void testObjectDetectorWithNDArrayInput() { Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK); Assert.assertEquals(expectedResult, actualResult); } + + @Test + public void testObjectDetectorWithIterableOfNDArrayInput() { + + NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape()); + Set inputL = new HashSet<>(); + inputL.add(inputArr); + Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult); + List> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK); + Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK); + Assert.assertEquals(expectedResult, actualResult); + } } diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java index e7a6c9652346..0d83c74fe901 100644 --- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java @@ -25,9 +25,7 @@ import org.junit.Test; import org.mockito.Mockito; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; public class PredictorTest { @@ -80,6 +78,31 @@ public void testPredictWithNDArray() { Assert.assertEquals(expectedResult, actualOutput); } + @Test + public void testPredictWithIterablesNDArray() { + + float[] tmpArr = new float[224]; + for (int y = 0; y < 224; y++) + tmpArr[y] = (int) (Math.random() * 10); + + NDArray arr = new org.apache.mxnet.javaapi.NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0)); + + Set inputSet = new HashSet<>(); + inputSet.add(arr); + + NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0)); + List expectedResult = new ArrayList<>(); + expectedResult.add(expected); + + Mockito.when(mockPredictor.predictWithNDArray(inputSet)).thenReturn(expectedResult); + + List actualOutput = mockPredictor.predictWithNDArray(inputSet); + + Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputSet); + + Assert.assertEquals(expectedResult, actualOutput); + } + @Test public void testPredictWithListOfFloatsAsInput() { List> input = new ArrayList<>();